1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11use wasmtime::{Caller, Linker};
12
13use crate::{
14 WASM::HostBridge::{
15 FunctionSignature,
16 HostBridgeImpl,
17 HostBridgeImpl as HostBridge,
18 HostFunctionCallback,
19 ParamType,
20 ReturnType,
21 },
22 dev_log,
23};
24
25pub struct HostFunctionRegistry {
27 functions:Arc<RwLock<HashMap<String, RegisteredHostFunction>>>,
29
30 #[allow(dead_code)]
32 bridge:Arc<HostBridge>,
33}
34
35#[derive(Debug, Clone)]
37struct RegisteredHostFunction {
38 #[allow(dead_code)]
40 name:String,
41
42 #[allow(dead_code)]
44 signature:FunctionSignature,
45
46 callback:Option<HostFunctionCallback>,
48
49 #[allow(dead_code)]
51 registered_at:u64,
52
53 stats:FunctionStats,
55}
56
/// Call statistics for a single exported host function. All counters start at zero.
#[derive(Debug, Clone, Default)]
pub struct FunctionStats {
	/// Number of recorded invocations.
	pub call_count:u64,

	/// Accumulated execution time, in nanoseconds.
	pub total_execution_ns:u64,

	/// Time of the most recent recorded call, if any (presumably seconds since the
	/// Unix epoch, like the registration timestamp — confirm with the writer).
	pub last_call_at:Option<u64>,

	/// Number of recorded invocations that failed.
	pub error_count:u64,
}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct ExportConfig {
76 pub auto_export:bool,
78
79 pub enable_stats:bool,
81
82 pub max_functions:usize,
84
85 pub name_prefix:Option<String>,
87}
88
89impl Default for ExportConfig {
90 fn default() -> Self {
91 Self {
92 auto_export:true,
93
94 enable_stats:true,
95
96 max_functions:1000,
97
98 name_prefix:Some("host_".to_string()),
99 }
100 }
101}
102
103pub struct FunctionExportImpl {
105 registry:Arc<HostFunctionRegistry>,
106
107 config:ExportConfig,
108}
109
110impl FunctionExportImpl {
111 pub fn new(bridge:Arc<HostBridge>) -> Self {
113 Self {
114 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
115
116 config:ExportConfig::default(),
117 }
118 }
119
120 pub fn with_config(bridge:Arc<HostBridge>, config:ExportConfig) -> Self {
122 Self {
123 registry:Arc::new(HostFunctionRegistry { functions:Arc::new(RwLock::new(HashMap::new())), bridge }),
124
125 config,
126 }
127 }
128
129 pub async fn register_function(
131 &self,
132
133 name:&str,
134
135 signature:FunctionSignature,
136
137 callback:HostFunctionCallback,
138 ) -> Result<()> {
139 dev_log!("wasm", "Registering host function for export: {}", name);
140
141 let functions = self.registry.functions.read().await;
142
143 if functions.len() >= self.config.max_functions {
145 return Err(anyhow::anyhow!(
146 "Maximum number of exported functions reached: {}",
147 self.config.max_functions
148 ));
149 }
150
151 drop(functions);
152
153 let mut functions = self.registry.functions.write().await;
154
155 let registered_at = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH)?.as_secs();
156
157 functions.insert(
158 name.to_string(),
159 RegisteredHostFunction {
160 name:name.to_string(),
161 signature,
162 callback:Some(callback),
163 registered_at,
164 stats:FunctionStats::default(),
165 },
166 );
167
168 dev_log!("wasm", "Host function registered for WASM export: {}", name);
169
170 Ok(())
171 }
172
173 pub async fn register_functions(
175 &self,
176
177 signatures:Vec<FunctionSignature>,
178
179 callbacks:Vec<HostFunctionCallback>,
180 ) -> Result<()> {
181 if signatures.len() != callbacks.len() {
182 return Err(anyhow::anyhow!("Number of signatures must match number of callbacks"));
183 }
184
185 for (sig, callback) in signatures.into_iter().zip(callbacks) {
186 let name = sig.name.clone();
187
188 self.register_function(&name, sig, callback).await?;
189 }
190
191 Ok(())
192 }
193
194 pub async fn export_to_linker<T>(&self, linker:&mut Linker<T>) -> Result<()>
196 where
197 T: Send + 'static, {
198 dev_log!(
199 "wasm",
200 "Exporting {} host functions to linker",
201 self.registry.functions.read().await.len()
202 );
203
204 let functions = self.registry.functions.read().await;
205
206 for (name, func) in functions.iter() {
207 self.export_single_function(linker, name, func)?;
208 }
209
210 dev_log!("wasm", "All host functions exported to linker");
211
212 Ok(())
213 }
214
215 fn export_single_function<T>(&self, linker:&mut Linker<T>, name:&str, func:&RegisteredHostFunction) -> Result<()>
217 where
218 T: Send + 'static, {
219 dev_log!("wasm", "Exporting function: {}", name);
220
221 let callback = func
222 .callback
223 .ok_or_else(|| anyhow::anyhow!("No callback available for function: {}", name))?;
224
225 let func_name = if let Some(prefix) = &self.config.name_prefix {
226 format!("{}{}", prefix, name)
227 } else {
228 name.to_string()
229 };
230
231 let func_name_for_debug = func_name.clone();
232
233 let func_name_inner = func_name.clone();
234
235 let _wrapped_callback =
237 move |_caller:Caller<'_, T>, args:&[wasmtime::Val]| -> Result<Vec<wasmtime::Val>, wasmtime::Trap> {
238 let _start = std::time::Instant::now();
239
240 let args_bytes:Result<Vec<bytes::Bytes>, _> = args
242 .iter()
243 .map(|arg| {
244 match arg {
245 wasmtime::Val::I32(i) => {
246 serde_json::to_vec(i)
247 .map(bytes::Bytes::from)
248 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
249 },
250 wasmtime::Val::I64(i) => {
251 serde_json::to_vec(i)
252 .map(bytes::Bytes::from)
253 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
254 },
255 wasmtime::Val::F32(f) => {
256 serde_json::to_vec(f)
257 .map(bytes::Bytes::from)
258 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
259 },
260 wasmtime::Val::F64(f) => {
261 serde_json::to_vec(f)
262 .map(bytes::Bytes::from)
263 .map_err(|e| anyhow::anyhow!("Serialization error: {}", e))
264 },
265 _ => Err(anyhow::anyhow!("Unsupported argument type")),
266 }
267 })
268 .collect();
269
270 let args_bytes = args_bytes.map_err(|_| {
271 dev_log!("wasm", "warn: error converting arguments for function '{}'", func_name_inner);
272 wasmtime::Trap::StackOverflow
273 })?;
274
275 let result = callback(args_bytes);
277
278 match result {
279 Ok(response_bytes) => {
280 let result_val:serde_json::Value = serde_json::from_slice(&response_bytes).map_err(|_| {
282 dev_log!("wasm", "warn: error deserializing response for function '{}'", func_name_inner);
283 wasmtime::Trap::StackOverflow
284 })?;
285
286 let ret_val = match result_val {
287 serde_json::Value::Number(n) => {
288 if let Some(i) = n.as_i64() {
289 wasmtime::Val::I32(i as i32)
290 } else if let Some(f) = n.as_f64() {
291 wasmtime::Val::I64(f as i64)
292 } else {
293 dev_log!("wasm", "warn: invalid number format for function '{}'", func_name_inner);
294
295 return Err(wasmtime::Trap::StackOverflow);
296 }
297 },
298
299 _ => {
300 dev_log!("wasm", "warn: unsupported response type for function '{}'", func_name_inner);
301
302 return Err(wasmtime::Trap::StackOverflow);
303 },
304 };
305
306 Ok(vec![ret_val])
307 },
308
309 Err(e) => {
310 dev_log!("wasm", "host function '{}' returned error: {}", func_name_inner, e);
312
313 Err(wasmtime::Trap::StackOverflow)
314 },
315 }
316 };
317
318 let _wasmparser_signature = wasmparser::FuncType::new([wasmparser::ValType::I32], [wasmparser::ValType::I32]);
320
321 let func_name_for_logging = func_name.clone();
325
326 linker
327 .func_wrap(
328 "_host", &func_name,
330 move |_caller:wasmtime::Caller<'_, T>, input_param:i32| -> i32 {
331 let start = std::time::Instant::now();
333
334 let args_bytes = match serde_json::to_vec(&input_param).map(bytes::Bytes::from) {
336 Ok(b) => b,
337 Err(e) => {
338 dev_log!(
339 "wasm",
340 "warn: serialization error for function '{}': {}",
341 func_name_for_logging,
342 e
343 );
344 return -1i32;
345 },
346 };
347
348 let result = callback(vec![args_bytes]);
350
351 match result {
352 Ok(response_bytes) => {
353 let result_val:serde_json::Value = match serde_json::from_slice(&response_bytes) {
355 Ok(v) => v,
356 Err(_) => {
357 dev_log!(
358 "wasm",
359 "warn: error deserializing response for function '{}'",
360 func_name_for_logging
361 );
362 return -1i32;
363 },
364 };
365
366 let ret_val = match result_val {
368 serde_json::Value::Number(n) => {
369 if let Some(i) = n.as_i64() {
370 i as i32
371 } else if let Some(f) = n.as_f64() {
372 f as i32
373 } else {
374 dev_log!(
375 "wasm",
376 "warn: invalid number format for function '{}'",
377 func_name_for_logging
378 );
379 -1i32
380 }
381 },
382 serde_json::Value::Bool(b) => {
383 if b {
384 1i32
385 } else {
386 0i32
387 }
388 },
389 _ => {
390 dev_log!(
391 "wasm",
392 "warn: unsupported response type for function '{}', expected number or bool",
393 func_name_for_logging
394 );
395 -1i32
396 },
397 };
398
399 let duration = start.elapsed();
401 dev_log!(
402 "wasm",
403 "[FunctionExport] Host function '{}' executed successfully in {}µs",
404 func_name_for_logging,
405 duration.as_micros()
406 );
407
408 ret_val
409 },
410 Err(e) => {
411 dev_log!(
413 "wasm",
414 "[FunctionExport] Host function '{}' returned error: {}",
415 func_name_for_logging,
416 e
417 );
418 -1i32
420 },
421 }
422 },
423 )
424 .map_err(|e| {
425 dev_log!(
426 "wasm",
427 "warn: [FunctionExport] failed to wrap host function '{}': {}",
428 func_name_for_debug,
429 e
430 );
431 e
432 })?;
433
434 dev_log!(
435 "wasm",
436 "[FunctionExport] Host function '{}' registered successfully",
437 func_name_for_debug
438 );
439
440 Ok(())
441 }
442
443 #[allow(dead_code)]
445 fn wasmtime_signature_from_signature(&self, _sig:&FunctionSignature) -> Result<wasmparser::FuncType> {
446 Ok(wasmparser::FuncType::new([], []))
449 }
450
451 pub async fn get_function_names(&self) -> Vec<String> {
453 self.registry.functions.read().await.keys().cloned().collect()
454 }
455
456 pub async fn get_function_stats(&self, name:&str) -> Option<FunctionStats> {
458 self.registry.functions.read().await.get(name).map(|f| f.stats.clone())
459 }
460
461 pub async fn unregister_function(&self, name:&str) -> Result<bool> {
463 let mut functions = self.registry.functions.write().await;
464
465 let removed = functions.remove(name).is_some();
466
467 if removed {
468 dev_log!("wasm", "Unregistered host function: {}", name);
469 } else {
470 dev_log!("wasm", "warn: attempted to unregister non-existent function: {}", name);
471 }
472
473 Ok(removed)
474 }
475
476 pub async fn clear(&self) {
478 dev_log!("wasm", "Clearing all registered host functions");
479
480 self.registry.functions.write().await.clear();
481 }
482}
483
#[cfg(test)]
mod tests {

	use super::*;

	/// Builds a synchronous `i32 -> i32` signature with the given name.
	fn i32_to_i32_signature(name:&str) -> FunctionSignature {
		FunctionSignature {
			name:name.to_string(),
			param_types:vec![ParamType::I32],
			return_type:Some(ReturnType::I32),
			is_async:false,
		}
	}

	/// Creates an exporter backed by a fresh host bridge and default config.
	fn new_export() -> FunctionExportImpl { FunctionExportImpl::new(Arc::new(HostBridgeImpl::new())) }

	#[tokio::test]
	async fn test_function_export_creation() {
		let export = new_export();

		assert!(export.get_function_names().await.is_empty());
	}

	#[tokio::test]
	async fn test_register_function() {
		let export = new_export();

		// Echo the first argument back, or an empty payload when there is none.
		let callback = |args:Vec<bytes::Bytes>| Ok(args.first().cloned().unwrap_or_default());

		let outcome:anyhow::Result<()> =
			export.register_function("echo", i32_to_i32_signature("echo"), callback).await;

		assert!(outcome.is_ok());

		assert_eq!(export.get_function_names().await.len(), 1);
	}

	#[tokio::test]
	async fn test_unregister_function() {
		let export = new_export();

		let callback = |_:Vec<bytes::Bytes>| Ok(bytes::Bytes::new());

		let _:anyhow::Result<()> = export.register_function("test", i32_to_i32_signature("test"), callback).await;

		let removed:bool = export.unregister_function("test").await.unwrap();

		assert!(removed);

		assert!(export.get_function_names().await.is_empty());
	}

	#[test]
	fn test_export_config_default() {
		let ExportConfig { auto_export, max_functions, .. } = ExportConfig::default();

		assert!(auto_export);

		assert_eq!(max_functions, 1000);
	}

	#[test]
	fn test_function_stats_default() {
		let FunctionStats { call_count, error_count, .. } = FunctionStats::default();

		assert_eq!((call_count, error_count), (0, 0));
	}
}