Skip to main content

Grove/Transport/
WASMTransport.rs

1//! WASM Transport Implementation
2//!
3//! Provides direct communication with WASM modules.
4//! Handles calls to and from WebAssembly instances.
5
6use std::{collections::HashMap, path::PathBuf, sync::Arc};
7
8use async_trait::async_trait;
9use base64::Engine;
10use bytes::Bytes;
11use serde::{Deserialize, Serialize};
12use tokio::sync::RwLock;
13
14use crate::{
15	Transport::{
16		Strategy::{TransportStats, TransportStrategy, TransportType},
17		TransportConfig,
18	},
19	WASM::{
20		HostBridge::HostBridgeImpl,
21		MemoryManager::{MemoryLimits, MemoryManagerImpl},
22		Runtime::{WASMConfig, WASMRuntime},
23		WASMStats,
24	},
25	dev_log,
26};
27
28/// WASM transport for direct module communication
29#[derive(Clone, Debug)]
30pub struct WASMTransportImpl {
31	/// WASM runtime
32	runtime:Arc<WASMRuntime>,
33
34	/// Memory manager
35	memory_manager:Arc<RwLock<MemoryManagerImpl>>,
36
37	/// Host bridge for communication
38	bridge:Arc<HostBridgeImpl>,
39
40	/// Loaded modules
41	modules:Arc<RwLock<HashMap<String, WASMModuleInfo>>>,
42
43	/// Transport configuration
44	#[allow(dead_code)]
45	config:TransportConfig,
46
47	/// Connection state
48	connected:Arc<RwLock<bool>>,
49
50	/// Transport statistics
51	stats:Arc<RwLock<TransportStats>>,
52}
53
54/// Information about a loaded WASM module
55#[derive(Debug, Clone)]
56pub struct WASMModuleInfo {
57	/// Module ID
58	pub id:String,
59
60	/// Module name (if available)
61	pub name:Option<String>,
62
63	/// Path to module file
64	pub path:Option<PathBuf>,
65
66	/// Module loaded timestamp
67	pub loaded_at:u64,
68
69	/// Function statistics
70	pub function_stats:HashMap<String, FunctionCallStats>,
71}
72
73/// Statistics for function calls
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct FunctionCallStats {
76	/// Number of calls
77	pub call_count:u64,
78
79	/// Total execution time in microseconds
80	pub total_time_us:u64,
81
82	/// Last call timestamp
83	pub last_call_at:Option<u64>,
84
85	/// Number of errors
86	pub error_count:u64,
87}
88
89impl FunctionCallStats {
90	/// Record a successful function call
91	pub fn record_call(&mut self, time_us:u64) {
92		self.call_count += 1;
93
94		self.total_time_us += time_us;
95
96		self.last_call_at = Some(
97			std::time::SystemTime::now()
98				.duration_since(std::time::UNIX_EPOCH)
99				.map(|d| d.as_secs())
100				.unwrap_or(0),
101		);
102	}
103
104	/// Record a failed function call
105	pub fn record_error(&mut self) { self.error_count += 1; }
106}
107
108impl Default for FunctionCallStats {
109	fn default() -> Self { Self { call_count:0, total_time_us:0, last_call_at:None, error_count:0 } }
110}
111
112impl WASMTransportImpl {
113	/// Create a new WASM transport with default configuration
114	pub fn new(enable_wasi:bool, memory_limit_mb:u64, max_execution_time_ms:u64) -> anyhow::Result<Self> {
115		let config = WASMConfig::new(memory_limit_mb, max_execution_time_ms, enable_wasi);
116
117		// Create runtime - this would normally be async, but for now we do it
118		// synchronously In production, this would need to be properly awaited
119		let runtime_result = tokio::runtime::Runtime::new()
120			.map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
121			.block_on(WASMRuntime::new(config.clone()))
122			.map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
123
124		let runtime = Arc::new(runtime_result);
125
126		let memory_limits = MemoryLimits::new(memory_limit_mb, (memory_limit_mb as f64 * 0.75) as u64, 100);
127
128		let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
129
130		let bridge = Arc::new(HostBridgeImpl::new());
131
132		Ok(Self {
133			runtime,
134			memory_manager,
135			bridge,
136			modules:Arc::new(RwLock::new(HashMap::new())),
137			config:TransportConfig::default(),
138			connected:Arc::new(RwLock::new(true)), // WASM transport is always "connected" locally
139			stats:Arc::new(RwLock::new(TransportStats::default())),
140		})
141	}
142
143	/// Create a new WASM transport with custom configuration
144	pub fn with_config(wasm_config:WASMConfig, transport_config:TransportConfig) -> anyhow::Result<Self> {
145		let runtime_result = tokio::runtime::Runtime::new()
146			.map_err(|e| anyhow::anyhow!("Failed to create tokio runtime: {}", e))?
147			.block_on(WASMRuntime::new(wasm_config.clone()))
148			.map_err(|e| anyhow::anyhow!("Failed to create WASM runtime: {}", e))?;
149
150		let runtime = Arc::new(runtime_result);
151
152		let memory_limits = MemoryLimits::new(
153			wasm_config.memory_limit_mb,
154			(wasm_config.memory_limit_mb as f64 * 0.75) as u64,
155			100,
156		);
157
158		let memory_manager = Arc::new(RwLock::new(MemoryManagerImpl::new(memory_limits)));
159
160		let bridge = Arc::new(HostBridgeImpl::new());
161
162		Ok(Self {
163			runtime,
164			memory_manager,
165			bridge,
166			modules:Arc::new(RwLock::new(HashMap::new())),
167			config:transport_config,
168			connected:Arc::new(RwLock::new(true)),
169			stats:Arc::new(RwLock::new(TransportStats::default())),
170		})
171	}
172
173	/// Get a reference to the WASM runtime
174	pub fn runtime(&self) -> &Arc<WASMRuntime> { &self.runtime }
175
176	/// Get a reference to the memory manager
177	pub fn memory_manager(&self) -> &Arc<RwLock<MemoryManagerImpl>> { &self.memory_manager }
178
179	/// Get a reference to the host bridge
180	pub fn bridge(&self) -> &Arc<HostBridgeImpl> { &self.bridge }
181
182	/// Get all loaded modules
183	pub async fn get_modules(&self) -> HashMap<String, WASMModuleInfo> { self.modules.read().await.clone() }
184
185	/// Get WASM runtime statistics
186	pub async fn get_wasm_stats(&self) -> WASMStats {
187		let memory_manager = self.memory_manager.read().await;
188
189		let managers = self.modules.read().await;
190
191		WASMStats {
192			modules_loaded:managers.len(),
193
194			active_instances:managers.len(), // In real implementation, track instances
195			total_memory_mb:memory_manager.current_usage_mb() as u64,
196
197			total_execution_time_ms:0, // Track from actual calls
198			function_calls:self.stats.read().await.messages_sent,
199		}
200	}
201
202	/// Call a function in a WASM module
203	pub async fn call_wasm_function(
204		&self,
205
206		module_id:&str,
207
208		function_name:&str,
209
210		args:Vec<Bytes>,
211	) -> anyhow::Result<Bytes> {
212		let start = std::time::Instant::now();
213
214		dev_log!(
215			"wasm",
216			"Calling WASM function: {}::{} with {} arguments",
217			module_id,
218			function_name,
219			args.len()
220		);
221
222		let modules = self.modules.read().await;
223
224		let _module = modules
225			.get(module_id)
226			.ok_or_else(|| anyhow::anyhow!("Module not found: {}", module_id))?;
227
228		// In a real implementation, this would call the actual WASM function
229		// For now, we return a mock response
230		let response = Bytes::new();
231
232		// Update statistics
233		let mut modules_mut = self.modules.write().await;
234
235		if let Some(module) = modules_mut.get_mut(module_id) {
236			let stats = module.function_stats.entry(function_name.to_string()).or_default();
237
238			stats.record_call(start.elapsed().as_micros() as u64);
239		}
240
241		drop(modules_mut);
242
243		// Update transport statistics
244		let mut stats = self.stats.write().await;
245
246		stats.record_sent(args.iter().map(|b| b.len() as u64).sum(), start.elapsed().as_micros() as u64);
247
248		stats.record_received(response.len() as u64);
249
250		Ok(response)
251	}
252}
253
254#[async_trait]
255impl TransportStrategy for WASMTransportImpl {
256	type Error = WASMTransportError;
257
258	async fn connect(&self) -> Result<(), Self::Error> {
259		dev_log!("transport", "WASM transport connecting");
260
261		// WASM transport is always "connected" locally
262		*self.connected.write().await = true;
263
264		dev_log!("transport", "WASM transport connected");
265
266		Ok(())
267	}
268
269	async fn send(&self, request:&[u8]) -> Result<Vec<u8>, Self::Error> {
270		let start = std::time::Instant::now();
271
272		if !self.is_connected() {
273			return Err(WASMTransportError::NotConnected);
274		}
275
276		dev_log!("transport", "Sending WASM transport request ({} bytes)", request.len());
277
278		// Parse request - it should contain module ID and function name
279		// For simplicity, we use a minimal format: module_id:function_name:base64_args
280		let request_str =
281			std::str::from_utf8(request).map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?;
282
283		let parts:Vec<&str> = request_str.splitn(3, ':').collect();
284
285		if parts.len() < 3 {
286			return Err(WASMTransportError::InvalidRequest("Invalid request format".to_string()));
287		}
288
289		let module_id = parts[0];
290
291		let function_name = parts[1];
292
293		let args_base64 = parts[2];
294
295		// Decode arguments from base64
296		use base64::engine::general_purpose::STANDARD;
297
298		let args = vec![Bytes::from(
299			STANDARD
300				.decode(args_base64)
301				.map_err(|e| WASMTransportError::InvalidRequest(e.to_string()))?,
302		)];
303
304		// Call the WASM function
305		let response = self
306			.call_wasm_function(module_id, function_name, args)
307			.await
308			.map_err(|e| WASMTransportError::FunctionCallFailed(e.to_string()))?;
309
310		// Convert response to Vec<u8>
311		let response_vec = response.to_vec();
312
313		let latency_us = start.elapsed().as_micros() as u64;
314
315		dev_log!("transport", "WASM transport request completed in {}µs", latency_us);
316
317		Ok(response_vec)
318	}
319
320	async fn send_no_response(&self, data:&[u8]) -> Result<(), Self::Error> {
321		if !self.is_connected() {
322			return Err(WASMTransportError::NotConnected);
323		}
324
325		dev_log!(
326			"transport",
327			"Sending WASM transport request without response ({} bytes)",
328			data.len()
329		);
330
331		// For fire-and-forget calls, we still execute but ignore the response
332		self.send(data).await?;
333
334		Ok(())
335	}
336
337	async fn close(&self) -> Result<(), Self::Error> {
338		dev_log!("transport", "Closing WASM transport");
339
340		*self.connected.write().await = false;
341
342		dev_log!("transport", "WASM transport closed");
343
344		Ok(())
345	}
346
347	fn is_connected(&self) -> bool { self.connected.blocking_read().to_owned() }
348
349	fn transport_type(&self) -> TransportType { TransportType::WASM }
350}
351
352/// WASM transport errors
353#[derive(Debug, thiserror::Error)]
354pub enum WASMTransportError {
355	/// Module not found error
356	#[error("Module not found: {0}")]
357	ModuleNotFound(String),
358
359	/// Function not found error
360	#[error("Function not found: {0}")]
361	FunctionNotFound(String),
362
363	/// Function call failed error
364	#[error("Function call failed: {0}")]
365	FunctionCallFailed(String),
366
367	/// Memory error
368	#[error("Memory error: {0}")]
369	MemoryError(String),
370
371	/// Runtime error
372	#[error("Runtime error: {0}")]
373	RuntimeError(String),
374
375	/// Invalid request error
376	#[error("Invalid request: {0}")]
377	InvalidRequest(String),
378
379	/// Not connected error
380	#[error("Not connected")]
381	NotConnected,
382
383	/// Compilation failed error
384	#[error("Compilation failed: {0}")]
385	CompilationFailed(String),
386
387	/// Timeout error
388	#[error("Timeout")]
389	Timeout,
390}
391
#[cfg(test)]
mod tests {

	use super::*;
	use crate::Transport::Strategy::TransportStrategy;

	/// Build a transport with WASI enabled, a 512 MB memory limit and a 30 000 ms execution limit.
	fn build_transport() -> anyhow::Result<WASMTransportImpl> { WASMTransportImpl::new(true, 512, 30000) }

	#[test]
	fn test_wasm_transport_creation() {
		let created = build_transport();

		assert!(created.is_ok());

		// A fresh transport reports itself connected without an explicit `connect`.
		assert!(created.unwrap().is_connected());
	}

	#[test]
	fn test_function_call_stats() {
		let mut call_stats = FunctionCallStats::default();

		call_stats.record_call(100);

		assert_eq!(call_stats.call_count, 1);
		assert_eq!(call_stats.total_time_us, 100);
		assert!(call_stats.last_call_at.is_some());
	}

	#[tokio::test]
	async fn test_wasm_transport_not_connected_after_close() {
		let transport = build_transport().unwrap();

		// The outcome of `close` is deliberately ignored; only the resulting state is checked.
		let _ = transport.close().await;

		assert!(!transport.is_connected());
	}

	#[tokio::test]
	async fn test_get_wasm_stats() {
		let transport = build_transport().unwrap();

		let wasm_stats = transport.get_wasm_stats().await;

		assert_eq!(wasm_stats.modules_loaded, 0);
		assert_eq!(wasm_stats.active_instances, 0);
	}
}