1use std::{collections::HashMap, sync::Arc};
8
9use anyhow::Result;
10use bytes::Bytes;
11use serde::{Serialize, de::DeserializeOwned};
12use tokio::sync::{RwLock, mpsc, oneshot};
13#[allow(unused_imports)]
14use wasmtime::{Caller, Extern, Func, Linker, Store};
15
16use crate::dev_log;
17
/// Errors produced by the WASM ⇄ host bridge.
#[derive(Debug, thiserror::Error)]
pub enum BridgeError {
	/// No host function is registered under the given name. Also returned when
	/// a registered entry carries neither a sync nor an async callback.
	#[error("Function not found: {0}")]
	FunctionNotFound(String),

	/// A function signature was rejected.
	/// NOTE(review): not constructed anywhere in this section of the file.
	#[error("Invalid function signature: {0}")]
	InvalidSignature(String),

	/// Encoding a value failed.
	/// NOTE(review): not constructed in this section; `serialize_to_bytes`
	/// returns an `anyhow` error instead.
	#[error("Serialization failed: {0}")]
	SerializationError(String),

	/// Decoding a value failed.
	/// NOTE(review): not constructed in this section; `deserialize_from_bytes`
	/// returns an `anyhow` error instead.
	#[error("Deserialization failed: {0}")]
	DeserializationError(String),

	/// A registered host callback returned an error. The payload is formatted
	/// as `"<function name>: <callback error>"`.
	#[error("Host function error: {0}")]
	HostFunctionError(String),

	/// A bridge operation timed out.
	/// NOTE(review): not constructed anywhere in this section of the file.
	#[error("Communication timeout")]
	Timeout,

	/// The other end of a channel or oneshot callback is gone, or the
	/// callback's sender was already consumed.
	#[error("Bridge closed")]
	BridgeClosed,
}
49
/// Shorthand for results whose error type is [`BridgeError`].
pub type BridgeResult<T> = Result<T, BridgeError>;
52
/// Declared shape of a host function exposed to WASM guests.
///
/// NOTE: field order is part of the serialized layout for order-based serde
/// formats; do not reorder fields casually.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct FunctionSignature {
	/// Function name.
	pub name:String,

	/// Parameter types, in call order.
	pub param_types:Vec<ParamType>,

	/// Return type, or `None` when none was declared.
	pub return_type:Option<ReturnType>,

	/// Whether the function is declared asynchronous. Note that
	/// `HostBridgeImpl::call_host_function` dispatches on which callback is
	/// registered, not on this flag.
	pub is_async:bool,
}
68
/// Type of a single host-function parameter.
///
/// NOTE: variant order affects index-based serde encodings; append new
/// variants rather than reordering.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum ParamType {
	/// 32-bit integer.
	I32,

	/// 64-bit integer.
	I64,

	/// 32-bit float.
	F32,

	/// 64-bit float.
	F64,

	/// Presumably an offset into guest linear memory — TODO confirm with callers.
	Ptr,

	/// Presumably the byte length paired with a preceding `Ptr` — TODO confirm.
	Len,
}
90
/// Return type of a host function.
///
/// NOTE: variant order affects index-based serde encodings; append new
/// variants rather than reordering.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
pub enum ReturnType {
	/// 32-bit integer.
	I32,

	/// 64-bit integer.
	I64,

	/// 32-bit float.
	F32,

	/// 64-bit float.
	F64,

	/// No return value.
	Void,
}
109
/// A request to invoke a host function, as exchanged over the bridge.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HostMessage {
	/// Identifier used to correlate this request with its `HostResponse`.
	pub message_id:String,

	/// Name of the host function to call.
	pub function:String,

	/// Encoded arguments. NOTE(review): `marshal_args` in this file expects
	/// JSON-encoded values; confirm that is the encoding used here too.
	pub args:Vec<Bytes>,

	/// Optional callback token. Presumably one returned by
	/// `HostBridgeImpl::create_async_callback` — verify against callers.
	pub callback_token:Option<u64>,
}
125
/// Result of a host function call, delivered back through an `AsyncCallback`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HostResponse {
	/// Identifier of the `HostMessage` this responds to.
	pub message_id:String,

	/// Whether the call succeeded.
	pub success:bool,

	/// Encoded result payload, if any.
	pub data:Option<Bytes>,

	/// Error description, if any.
	pub error:Option<String>,
}
141
/// One-shot handle for delivering a `HostResponse`.
///
/// The sender sits behind `Arc<Mutex<Option<..>>>`, so all clones share it.
/// Only the first successful `send` across every clone can deliver; later
/// calls find `None` and fail with `BridgeError::BridgeClosed`.
#[derive(Clone)]
pub struct AsyncCallback {
	/// Shared, take-once sending half of a oneshot channel.
	sender:Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<HostResponse>>>>,

	/// Identifier of the message this callback answers.
	message_id:String,
}
151
152impl std::fmt::Debug for AsyncCallback {
153 fn fmt(&self, f:&mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.debug_struct("AsyncCallback").field("message_id", &self.message_id).finish()
155 }
156}
157
158impl AsyncCallback {
159 pub async fn send(self, response:HostResponse) -> Result<()> {
161 let mut sender_opt = self.sender.lock().await;
162
163 if let Some(sender) = sender_opt.take() {
164 sender.send(response).map_err(|_| BridgeError::BridgeClosed)?;
165
166 Ok(())
167 } else {
168 Err(BridgeError::BridgeClosed.into())
169 }
170 }
171}
172
/// A function call message carried on the bridge's WASM channels.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct WASMMessage {
	/// Name of the function the message targets.
	pub function:String,

	/// Encoded call arguments.
	pub args:Vec<Bytes>,
}
182
/// Synchronous host function: encoded arguments in, encoded result out.
/// Being a plain `fn` pointer, it cannot capture state.
pub type HostFunctionCallback = fn(Vec<Bytes>) -> Result<Bytes>;
185
/// Asynchronous host function: returns a boxed `Send + Unpin` future yielding
/// the encoded result. Being a plain `fn` pointer, it cannot capture state.
pub type AsyncHostFunctionCallback =
	fn(Vec<Bytes>) -> Box<dyn std::future::Future<Output = Result<Bytes>> + Send + Unpin>;
189
/// A registered host function entry.
///
/// The `HostBridgeImpl` registration methods set exactly one of `callback`
/// or `async_callback`.
#[derive(Debug)]
pub struct HostFunction {
	/// Name the function is registered under.
	pub name:String,

	/// Declared signature.
	pub signature:FunctionSignature,

	/// Synchronous implementation, if registered via `register_host_function`.
	#[allow(dead_code)]
	pub callback:Option<HostFunctionCallback>,

	/// Asynchronous implementation, if registered via `register_async_host_function`.
	#[allow(dead_code)]
	pub async_callback:Option<AsyncHostFunctionCallback>,
}
207
/// Registry of host functions plus the message channels and pending async
/// callbacks used to talk to WASM guests.
#[derive(Debug)]
pub struct HostBridgeImpl {
	/// Host functions keyed by name.
	host_functions:Arc<RwLock<HashMap<String, HostFunction>>>,

	/// Receiver for messages from WASM. NOTE(review): `new` drops the matching
	/// sender, so this currently never yields a message.
	wasm_to_host_rx:mpsc::UnboundedReceiver<WASMMessage>,

	/// Sender for messages to WASM. NOTE(review): `new` drops the matching
	/// receiver, so sends currently always fail.
	host_to_wasm_tx:mpsc::UnboundedSender<WASMMessage>,

	/// Pending async callbacks keyed by token.
	async_callbacks:Arc<RwLock<HashMap<u64, AsyncCallback>>>,

	/// Monotonic source of callback tokens, starting at 0.
	next_callback_token:Arc<std::sync::atomic::AtomicU64>,
}
226
227impl HostBridgeImpl {
228 pub fn new() -> Self {
230 let (_wasm_to_host_tx, wasm_to_host_rx) = mpsc::unbounded_channel();
231
232 let (host_to_wasm_tx, host_to_wasm_rx) = mpsc::unbounded_channel();
233
234 drop(host_to_wasm_rx);
237
238 Self {
239 host_functions:Arc::new(RwLock::new(HashMap::new())),
240
241 wasm_to_host_rx,
242
243 host_to_wasm_tx,
244
245 async_callbacks:Arc::new(RwLock::new(HashMap::new())),
246
247 next_callback_token:Arc::new(std::sync::atomic::AtomicU64::new(0)),
248 }
249 }
250
251 pub async fn register_host_function(
253 &self,
254
255 name:&str,
256
257 signature:FunctionSignature,
258
259 callback:HostFunctionCallback,
260 ) -> BridgeResult<()> {
261 dev_log!("wasm", "Registering host function: {}", name);
262
263 let mut functions = self.host_functions.write().await;
264
265 if functions.contains_key(name) {
266 dev_log!("wasm", "warn: host function already registered: {}", name);
267 }
268
269 functions.insert(
270 name.to_string(),
271 HostFunction { name:name.to_string(), signature, callback:Some(callback), async_callback:None },
272 );
273
274 dev_log!("wasm", "Host function registered successfully: {}", name);
275
276 Ok(())
277 }
278
279 pub async fn register_async_host_function(
281 &self,
282
283 name:&str,
284
285 signature:FunctionSignature,
286
287 callback:AsyncHostFunctionCallback,
288 ) -> BridgeResult<()> {
289 dev_log!("wasm", "Registering async host function: {}", name);
290
291 let mut functions = self.host_functions.write().await;
292
293 functions.insert(
294 name.to_string(),
295 HostFunction { name:name.to_string(), signature, callback:None, async_callback:Some(callback) },
296 );
297
298 dev_log!("wasm", "Async host function registered successfully: {}", name);
299
300 Ok(())
301 }
302
303 pub async fn call_host_function(&self, function_name:&str, args:Vec<Bytes>) -> BridgeResult<Bytes> {
305 dev_log!("wasm", "Calling host function: {}", function_name);
306
307 let functions = self.host_functions.read().await;
308
309 let func = functions
310 .get(function_name)
311 .ok_or_else(|| BridgeError::FunctionNotFound(function_name.to_string()))?;
312
313 if let Some(callback) = func.callback {
314 let result =
316 callback(args).map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
317
318 dev_log!("wasm", "Host function call completed: {}", function_name);
319
320 Ok(result)
321 } else if let Some(async_callback) = func.async_callback {
322 let future = async_callback(args);
324
325 let result = future
326 .await
327 .map_err(|e| BridgeError::HostFunctionError(format!("{}: {}", function_name, e)))?;
328
329 dev_log!("wasm", "Async host function call completed: {}", function_name);
330
331 Ok(result)
332 } else {
333 Err(BridgeError::FunctionNotFound(format!(
334 "No callback for function: {}",
335 function_name
336 )))
337 }
338 }
339
340 pub async fn send_to_wasm(&self, message:WASMMessage) -> BridgeResult<()> {
342 let function_name = message.function.clone();
343
344 self.host_to_wasm_tx.send(message).map_err(|_| BridgeError::BridgeClosed)?;
345
346 dev_log!("wasm", "Message sent to WASM: {}", function_name);
347
348 Ok(())
349 }
350
351 pub async fn receive_from_wasm(&mut self) -> Option<WASMMessage> { self.wasm_to_host_rx.recv().await }
353
354 pub async fn create_async_callback(&self, message_id:String) -> (AsyncCallback, u64) {
356 let token = self.next_callback_token.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
357
358 let (tx, _rx) = oneshot::channel();
359
360 let callback = AsyncCallback {
362 sender:Arc::new(tokio::sync::Mutex::new(Some(tx))),
363
364 message_id:message_id.clone(),
365 };
366
367 self.async_callbacks.write().await.insert(token, callback.clone());
368
369 (callback, token)
370 }
371
372 pub async fn get_callback(&self, token:u64) -> Option<AsyncCallback> {
374 self.async_callbacks.write().await.remove(&token)
375 }
376
377 pub async fn get_host_functions(&self) -> Vec<String> { self.host_functions.read().await.keys().cloned().collect() }
379
380 pub async fn unregister_host_function(&self, name:&str) -> bool {
382 let mut functions = self.host_functions.write().await;
383
384 let removed = functions.remove(name).is_some();
385
386 if removed {
387 dev_log!("wasm", "Host function unregistered: {}", name);
388 }
389
390 removed
391 }
392
393 pub async fn clear(&self) {
395 dev_log!("wasm", "Clearing all registered host functions");
396
397 self.host_functions.write().await.clear();
398
399 self.async_callbacks.write().await.clear();
400 }
401}
402
403impl Default for HostBridgeImpl {
404 fn default() -> Self { Self::new() }
405}
406
407pub fn serialize_to_bytes<T:Serialize>(data:&T) -> Result<Bytes> {
409 serde_json::to_vec(data)
410 .map(Bytes::from)
411 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
412}
413
414pub fn deserialize_from_bytes<T:DeserializeOwned>(bytes:&Bytes) -> Result<T> {
416 serde_json::from_slice(bytes).map_err(|e| anyhow::anyhow!("Deserialization error: {}", e))
417}
418
419pub fn marshal_args(args:Vec<Bytes>) -> Result<Vec<wasmtime::Val>> {
421 args.iter()
422 .map(|bytes| {
423 let value:serde_json::Value = serde_json::from_slice(bytes)?;
424 match value {
425 serde_json::Value::Number(n) => {
426 if let Some(i) = n.as_i64() {
427 Ok(wasmtime::Val::I32(i as i32))
428 } else if let Some(f) = n.as_f64() {
429 Ok(wasmtime::Val::F64(f.to_bits()))
430 } else {
431 Err(anyhow::anyhow!("Invalid number value"))
432 }
433 },
434 _ => Err(anyhow::anyhow!("Unsupported argument type")),
435 }
436 })
437 .collect()
438}
439
440pub fn unmarshal_return(val:wasmtime::Val) -> Result<Bytes> {
442 match val {
443 wasmtime::Val::I32(i) => {
444 let json = serde_json::to_string(&i)?;
445
446 Ok(Bytes::from(json))
447 },
448
449 wasmtime::Val::I64(i) => {
450 let json = serde_json::to_string(&i)?;
451
452 Ok(Bytes::from(json))
453 },
454
455 wasmtime::Val::F32(f) => {
456 let json = serde_json::to_string(&f)?;
457
458 Ok(Bytes::from(json))
459 },
460
461 wasmtime::Val::F64(f) => {
462 let json = serde_json::to_string(&f)?;
463
464 Ok(Bytes::from(json))
465 },
466
467 _ => Err(anyhow::anyhow!("Unsupported return type")),
468 }
469}
470
#[cfg(test)]
mod tests {

	use super::*;

	/// Builds a simple one-parameter `I32 -> I32` signature for `name`.
	fn i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),

			param_types:vec![ParamType::I32],

			return_type:Some(ReturnType::I32),

			is_async:false,
		}
	}

	#[test]
	fn test_function_signature_creation() {
		let signature = FunctionSignature {
			name:"test_func".to_string(),

			param_types:vec![ParamType::I32, ParamType::Ptr],

			return_type:Some(ReturnType::I32),

			is_async:false,
		};

		assert_eq!(signature.name, "test_func");

		assert_eq!(signature.param_types.len(), 2);
	}

	#[tokio::test]
	async fn test_host_bridge_creation() {
		let bridge = HostBridgeImpl::new();

		assert_eq!(bridge.get_host_functions().await.len(), 0);
	}

	#[tokio::test]
	async fn test_register_host_function() {
		let bridge = HostBridgeImpl::new();

		let result = bridge
			.register_host_function("echo", i32_signature("echo"), |args| Ok(args[0].clone()))
			.await;

		assert!(result.is_ok());

		assert_eq!(bridge.get_host_functions().await.len(), 1);
	}

	#[tokio::test]
	async fn test_call_host_function_echo() {
		let bridge = HostBridgeImpl::new();

		bridge
			.register_host_function("echo", i32_signature("echo"), |args| Ok(args[0].clone()))
			.await
			.unwrap();

		let out = bridge.call_host_function("echo", vec![Bytes::from_static(b"hello")]).await.unwrap();

		assert_eq!(out, Bytes::from_static(b"hello"));
	}

	#[tokio::test]
	async fn test_call_unknown_host_function() {
		let bridge = HostBridgeImpl::new();

		let err = bridge.call_host_function("missing", Vec::new()).await.unwrap_err();

		assert!(matches!(err, BridgeError::FunctionNotFound(ref name) if name.as_str() == "missing"));
	}

	#[tokio::test]
	async fn test_unregister_host_function() {
		let bridge = HostBridgeImpl::new();

		bridge
			.register_host_function("echo", i32_signature("echo"), |args| Ok(args[0].clone()))
			.await
			.unwrap();

		assert!(bridge.unregister_host_function("echo").await);

		assert!(!bridge.unregister_host_function("echo").await);

		assert!(bridge.get_host_functions().await.is_empty());
	}

	#[test]
	fn test_serialize_deserialize() {
		let data = vec![1, 2, 3, 4, 5];

		let bytes = serialize_to_bytes(&data).unwrap();

		let recovered:Vec<i32> = deserialize_from_bytes(&bytes).unwrap();

		assert_eq!(data, recovered);
	}

	#[test]
	fn test_marshal_unmarshal() {
		// 2.5 instead of 3.14: clippy's deny-by-default `approx_constant` lint rejects 3.14.
		let args = vec![serialize_to_bytes(&42i32).unwrap(), serialize_to_bytes(&2.5f64).unwrap()];

		let marshaled = marshal_args(args).unwrap();

		assert_eq!(marshaled.len(), 2);

		assert_eq!(marshaled[0].i32(), Some(42));

		assert_eq!(marshaled[1].f64(), Some(2.5));

		let returned = unmarshal_return(wasmtime::Val::I32(42)).unwrap();

		assert_eq!(deserialize_from_bytes::<i32>(&returned).unwrap(), 42);
	}

	#[test]
	fn test_marshal_rejects_non_numbers() {
		let args = vec![serialize_to_bytes(&"text").unwrap()];

		assert!(marshal_args(args).is_err());
	}
}