Skip to main content

Grove/WASM/
HostBridge.rs

1//! Host Bridge
2//!
3//! Provides bidirectional communication between the host (Grove) and WASM
4//! modules. Handles function calls, data transfer, and marshalling between the
5//! two environments.
6
7use std::{collections::HashMap, sync::Arc};
8
9use anyhow::Result;
10use bytes::Bytes;
11use serde::{Serialize, de::DeserializeOwned};
12use tokio::sync::{RwLock, mpsc, oneshot};
13#[allow(unused_imports)]
14use wasmtime::{Caller, Extern, Func, Linker, Store};
15
16use crate::dev_log;
17
18/// Host bridge error types
19#[derive(Debug, thiserror::Error)]
20pub enum BridgeError {
21	/// Function not found error
22	#[error("Function not found: {0}")]
23	FunctionNotFound(String),
24
25	/// Invalid function signature error
26	#[error("Invalid function signature: {0}")]
27	InvalidSignature(String),
28
29	/// Serialization failed error
30	#[error("Serialization failed: {0}")]
31	SerializationError(String),
32
33	/// Deserialization failed error
34	#[error("Deserialization failed: {0}")]
35	DeserializationError(String),
36
37	/// Host function error
38	#[error("Host function error: {0}")]
39	HostFunctionError(String),
40
41	/// Communication timeout error
42	#[error("Communication timeout")]
43	Timeout,
44
45	/// Bridge closed error
46	#[error("Bridge closed")]
47	BridgeClosed,
48}
49
50/// Type-safe result for operations
51pub type BridgeResult<T> = Result<T, BridgeError>;
52
53/// Function signature information
54#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
55pub struct FunctionSignature {
56	/// Function name
57	pub name:String,
58
59	/// Parameter types
60	pub param_types:Vec<ParamType>,
61
62	/// Return type
63	pub return_type:Option<ReturnType>,
64
65	/// Whether this is an async function
66	pub is_async:bool,
67}
68
69/// Parameter types for WASM functions
70#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
71pub enum ParamType {
72	/// 32-bit signed integer parameter
73	I32,
74
75	/// 64-bit signed integer parameter
76	I64,
77
78	/// 32-bit floating point parameter
79	F32,
80
81	/// 64-bit floating point parameter
82	F64,
83
84	/// Pointer to memory
85	Ptr,
86
87	/// Length parameter following a pointer
88	Len,
89}
90
91/// Return types for WASM functions
92#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
93pub enum ReturnType {
94	/// 32-bit signed integer return type
95	I32,
96
97	/// 64-bit signed integer return type
98	I64,
99
100	/// 32-bit floating point return type
101	F32,
102
103	/// 64-bit floating point return type
104	F64,
105
106	/// No return value (void)
107	Void,
108}
109
110/// Message sent from WASM to host
111#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
112pub struct HostMessage {
113	/// Message ID for correlation
114	pub message_id:String,
115
116	/// Function name to call
117	pub function:String,
118
119	/// Serialized arguments
120	pub args:Vec<Bytes>,
121
122	/// Callback token for async responses
123	pub callback_token:Option<u64>,
124}
125
126/// Response sent from host to WASM
127#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
128pub struct HostResponse {
129	/// Correlating message ID
130	pub message_id:String,
131
132	/// Success flag
133	pub success:bool,
134
135	/// Response data
136	pub data:Option<Bytes>,
137
138	/// Error message if failed
139	pub error:Option<String>,
140}
141
142/// Callback for async function responses
143#[derive(Clone)]
144pub struct AsyncCallback {
145	/// Sender for transmitting the response
146	sender:Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<HostResponse>>>>,
147
148	/// Message ID for correlation
149	message_id:String,
150}
151
152impl std::fmt::Debug for AsyncCallback {
153	fn fmt(&self, f:&mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154		f.debug_struct("AsyncCallback").field("message_id", &self.message_id).finish()
155	}
156}
157
158impl AsyncCallback {
159	/// Send response through the callback
160	pub async fn send(self, response:HostResponse) -> Result<()> {
161		let mut sender_opt = self.sender.lock().await;
162
163		if let Some(sender) = sender_opt.take() {
164			sender.send(response).map_err(|_| BridgeError::BridgeClosed)?;
165
166			Ok(())
167		} else {
168			Err(BridgeError::BridgeClosed.into())
169		}
170	}
171}
172
173/// Message from host to WASM
174#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
175pub struct WASMMessage {
176	/// Target function in WASM
177	pub function:String,
178
179	/// Arguments
180	pub args:Vec<Bytes>,
181}
182
183/// Host function callback type
184pub type HostFunctionCallback = fn(Vec<Bytes>) -> Result<Bytes>;
185
186/// Async host function callback type
187pub type AsyncHostFunctionCallback =
188	fn(Vec<Bytes>) -> Box<dyn std::future::Future<Output = Result<Bytes>> + Send + Unpin>;
189
190/// Host function definition
191#[derive(Debug)]
192pub struct HostFunction {
193	/// Function name
194	pub name:String,
195
196	/// Function signature
197	pub signature:FunctionSignature,
198
199	/// Synchronous callback - not serializable (skip serde derive)
200	#[allow(dead_code)]
201	pub callback:Option<HostFunctionCallback>,
202
203	/// Async callback - not serializable (skip serde derive)
204	#[allow(dead_code)]
205	pub async_callback:Option<AsyncHostFunctionCallback>,
206}
207
208/// Host Bridge for WASM communication
209#[derive(Debug)]
210pub struct HostBridgeImpl {
211	/// Registry of host functions exported to WASM
212	host_functions:Arc<RwLock<HashMap<String, HostFunction>>>,
213
214	/// Channel for receiving messages from WASM
215	wasm_to_host_rx:mpsc::UnboundedReceiver<WASMMessage>,
216
217	/// Channel for sending messages to WASM
218	host_to_wasm_tx:mpsc::UnboundedSender<WASMMessage>,
219
220	/// Active async callbacks
221	async_callbacks:Arc<RwLock<HashMap<u64, AsyncCallback>>>,
222
223	/// Next callback token
224	next_callback_token:Arc<std::sync::atomic::AtomicU64>,
225}
226
227impl HostBridgeImpl {
228	/// Create a new host bridge
229	pub fn new() -> Self {
230		let (_wasm_to_host_tx, wasm_to_host_rx) = mpsc::unbounded_channel();
231
232		let (host_to_wasm_tx, host_to_wasm_rx) = mpsc::unbounded_channel();
233
234		// In a real implementation, we'd need to wire these up properly
235		// For now, we drop the receiver to avoid unused warnings
236		drop(host_to_wasm_rx);
237
238		Self {
239			host_functions:Arc::new(RwLock::new(HashMap::new())),
240
241			wasm_to_host_rx,
242
243			host_to_wasm_tx,
244
245			async_callbacks:Arc::new(RwLock::new(HashMap::new())),
246
247			next_callback_token:Arc::new(std::sync::atomic::AtomicU64::new(0)),
248		}
249	}
250
251	/// Register a host function to be exported to WASM
252	pub async fn register_host_function(
253		&self,
254
255		name:&str,
256
257		signature:FunctionSignature,
258
259		callback:HostFunctionCallback,
260	) -> BridgeResult<()> {
261		dev_log!("wasm", "Registering host function: {}", name);
262
263		let mut functions = self.host_functions.write().await;
264
265		if functions.contains_key(name) {
266			dev_log!("wasm", "warn: host function already registered: {}", name);
267		}
268
269		functions.insert(
270			name.to_string(),
271			HostFunction { name:name.to_string(), signature, callback:Some(callback), async_callback:None },
272		);
273
274		dev_log!("wasm", "Host function registered successfully: {}", name);
275
276		Ok(())
277	}
278
279	/// Register an async host function
280	pub async fn register_async_host_function(
281		&self,
282
283		name:&str,
284
285		signature:FunctionSignature,
286
287		callback:AsyncHostFunctionCallback,
288	) -> BridgeResult<()> {
289		dev_log!("wasm", "Registering async host function: {}", name);
290
291		let mut functions = self.host_functions.write().await;
292
293		functions.insert(
294			name.to_string(),
295			HostFunction { name:name.to_string(), signature, callback:None, async_callback:Some(callback) },
296		);
297
298		dev_log!("wasm", "Async host function registered successfully: {}", name);
299
300		Ok(())
301	}
302
303	/// Call a host function from WASM
304	pub async fn call_host_function(&self, function_name:&str, args:Vec<Bytes>) -> BridgeResult<Bytes> {
305		dev_log!("wasm", "Calling host function: {}", function_name);
306
307		let functions = self.host_functions.read().await;
308
309		let func = functions
310			.get(function_name)
311			.ok_or_else(|| BridgeError::FunctionNotFound(function_name.to_string()))?;
312
313		if let Some(callback) = func.callback {
314			// Synchronous call
315			let result =
316				callback(args).map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
317
318			dev_log!("wasm", "Host function call completed: {}", function_name);
319
320			Ok(result)
321		} else if let Some(async_callback) = func.async_callback {
322			// Async call
323			let future = async_callback(args);
324
325			let result = future
326				.await
327				.map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
328
329			dev_log!("wasm", "Async host function call completed: {}", function_name);
330
331			Ok(result)
332		} else {
333			Err(BridgeError::FunctionNotFound(format!(
334				"No callback for function: {}",
335				function_name
336			)))
337		}
338	}
339
340	/// Send a message to WASM
341	pub async fn send_to_wasm(&self, message:WASMMessage) -> BridgeResult<()> {
342		let function_name = message.function.clone();
343
344		self.host_to_wasm_tx.send(message).map_err(|_| BridgeError::BridgeClosed)?;
345
346		dev_log!("wasm", "Message sent to WASM: {}", function_name);
347
348		Ok(())
349	}
350
351	/// Receive a message from WASM (blocking)
352	pub async fn receive_from_wasm(&mut self) -> Option<WASMMessage> { self.wasm_to_host_rx.recv().await }
353
354	/// Create async callback
355	pub async fn create_async_callback(&self, message_id:String) -> (AsyncCallback, u64) {
356		let token = self.next_callback_token.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
357
358		let (tx, _rx) = oneshot::channel();
359
360		// Create callback with Arc-wrapped sender
361		let callback = AsyncCallback {
362			sender:Arc::new(tokio::sync::Mutex::new(Some(tx))),
363
364			message_id:message_id.clone(),
365		};
366
367		self.async_callbacks.write().await.insert(token, callback.clone());
368
369		(callback, token)
370	}
371
372	/// Get callback by token
373	pub async fn get_callback(&self, token:u64) -> Option<AsyncCallback> {
374		self.async_callbacks.write().await.remove(&token)
375	}
376
377	/// Get all registered host functions
378	pub async fn get_host_functions(&self) -> Vec<String> { self.host_functions.read().await.keys().cloned().collect() }
379
380	/// Unregister a host function
381	pub async fn unregister_host_function(&self, name:&str) -> bool {
382		let mut functions = self.host_functions.write().await;
383
384		let removed = functions.remove(name).is_some();
385
386		if removed {
387			dev_log!("wasm", "Host function unregistered: {}", name);
388		}
389
390		removed
391	}
392
393	/// Clear all registered functions
394	pub async fn clear(&self) {
395		dev_log!("wasm", "Clearing all registered host functions");
396
397		self.host_functions.write().await.clear();
398
399		self.async_callbacks.write().await.clear();
400	}
401}
402
403impl Default for HostBridgeImpl {
404	fn default() -> Self { Self::new() }
405}
406
407/// Utility function to serialize data to Bytes
408pub fn serialize_to_bytes<T:Serialize>(data:&T) -> Result<Bytes> {
409	serde_json::to_vec(data)
410		.map(Bytes::from)
411		.map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
412}
413
414/// Utility function to deserialize Bytes to data
415pub fn deserialize_from_bytes<T:DeserializeOwned>(bytes:&Bytes) -> Result<T> {
416	serde_json::from_slice(bytes).map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))
417}
418
419/// Marshal arguments for WASM function call
420pub fn marshal_args(args:Vec<Bytes>) -> Result<Vec<wasmtime::Val>> {
421	args.iter()
422		.map(|bytes| {
423			let value:serde_json::Value = serde_json::from_slice(bytes)?;
424			match value {
425				serde_json::Value::Number(n) => {
426					if let Some(i) = n.as_i64() {
427						Ok(wasmtime::Val::I32(i as i32))
428					} else if let Some(f) = n.as_f64() {
429						Ok(wasmtime::Val::F64(f.to_bits()))
430					} else {
431						Err(anyhow::anyhow!("Invalid number value"))
432					}
433				},
434				_ => Err(anyhow::anyhow!("Unsupported argument type")),
435			}
436		})
437		.collect()
438}
439
440/// Unmarshal return values from WASM function call
441pub fn unmarshal_return(val:wasmtime::Val) -> Result<Bytes> {
442	match val {
443		wasmtime::Val::I32(i) => {
444			let json = serde_json::to_string(&i)?;
445
446			Ok(Bytes::from(json))
447		},
448
449		wasmtime::Val::I64(i) => {
450			let json = serde_json::to_string(&i)?;
451
452			Ok(Bytes::from(json))
453		},
454
455		wasmtime::Val::F32(f) => {
456			let json = serde_json::to_string(&f)?;
457
458			Ok(Bytes::from(json))
459		},
460
461		wasmtime::Val::F64(f) => {
462			let json = serde_json::to_string(&f)?;
463
464			Ok(Bytes::from(json))
465		},
466
467		_ => Err(anyhow::anyhow!("Unsupported return type")),
468	}
469}
470
#[cfg(test)]
mod tests {

	use super::*;

	/// Build a non-async signature returning `I32` with the given parameters.
	fn i32_signature(name:&str, param_types:Vec<ParamType>) -> FunctionSignature {
		FunctionSignature { name:name.to_string(), param_types, return_type:Some(ReturnType::I32), is_async:false }
	}

	#[test]
	fn test_function_signature_creation() {
		let signature = i32_signature("test_func", vec![ParamType::I32, ParamType::Ptr]);

		assert_eq!(signature.name, "test_func");

		assert_eq!(signature.param_types, vec![ParamType::I32, ParamType::Ptr]);

		assert_eq!(signature.return_type, Some(ReturnType::I32));
	}

	#[tokio::test]
	async fn test_host_bridge_creation() {
		let bridge = HostBridgeImpl::new();

		assert!(bridge.get_host_functions().await.is_empty());
	}

	#[tokio::test]
	async fn test_register_host_function() {
		let bridge = HostBridgeImpl::new();

		let result = bridge
			.register_host_function("echo", i32_signature("echo", vec![ParamType::I32]), |args| Ok(args[0].clone()))
			.await;

		assert!(result.is_ok());

		assert_eq!(bridge.get_host_functions().await, vec!["echo".to_string()]);
	}

	#[tokio::test]
	async fn test_call_registered_host_function() {
		let bridge = HostBridgeImpl::new();

		bridge
			.register_host_function("echo", i32_signature("echo", vec![ParamType::I32]), |args| Ok(args[0].clone()))
			.await
			.unwrap();

		let out = bridge.call_host_function("echo", vec![Bytes::from_static(b"hello")]).await.unwrap();

		assert_eq!(out, Bytes::from_static(b"hello"));
	}

	#[tokio::test]
	async fn test_call_unknown_host_function() {
		let bridge = HostBridgeImpl::new();

		let err = bridge.call_host_function("missing", Vec::new()).await.unwrap_err();

		assert!(matches!(err, BridgeError::FunctionNotFound(ref name) if name.as_str() == "missing"));
	}

	#[tokio::test]
	async fn test_host_function_error_is_wrapped() {
		let bridge = HostBridgeImpl::new();

		bridge
			.register_host_function("fail", i32_signature("fail", Vec::new()), |_| Err(anyhow::anyhow!("boom")))
			.await
			.unwrap();

		let err = bridge.call_host_function("fail", Vec::new()).await.unwrap_err();

		assert!(matches!(err, BridgeError::HostFunctionError(ref msg) if msg.as_str() == "fail: boom"));
	}

	#[tokio::test]
	async fn test_unregister_host_function() {
		let bridge = HostBridgeImpl::new();

		bridge
			.register_host_function("echo", i32_signature("echo", vec![ParamType::I32]), |args| Ok(args[0].clone()))
			.await
			.unwrap();

		assert!(bridge.unregister_host_function("echo").await);

		assert!(!bridge.unregister_host_function("echo").await);

		assert!(bridge.get_host_functions().await.is_empty());
	}

	#[test]
	fn test_serialize_deserialize() {
		let data = vec![1, 2, 3, 4, 5];

		let bytes = serialize_to_bytes(&data).unwrap();

		let recovered:Vec<i32> = deserialize_from_bytes(&bytes).unwrap();

		assert_eq!(data, recovered);
	}

	#[test]
	fn test_marshal_unmarshal() {
		// 2.5 is exactly representable, so the float round-trip can be checked
		// exactly (and, unlike 3.14, does not trip clippy's `approx_constant`).
		let args = vec![serialize_to_bytes(&42i32).unwrap(), serialize_to_bytes(&2.5f64).unwrap()];

		let marshaled = marshal_args(args).unwrap();

		assert_eq!(marshaled.len(), 2);

		assert!(matches!(marshaled[0], wasmtime::Val::I32(42)));

		match marshaled[1] {
			wasmtime::Val::F64(bits) => assert_eq!(f64::from_bits(bits), 2.5),
			_ => panic!("expected an F64 value"),
		}

		assert_eq!(unmarshal_return(wasmtime::Val::I32(42)).unwrap(), Bytes::from_static(b"42"));

		assert_eq!(unmarshal_return(wasmtime::Val::I64(-7)).unwrap(), Bytes::from_static(b"-7"));
	}

	#[test]
	fn test_marshal_rejects_non_numbers() {
		let args = vec![serialize_to_bytes(&"text").unwrap()];

		assert!(marshal_args(args).is_err());
	}
}