1use std::{collections::HashMap, sync::Arc};
7
8use anyhow::Result;
9use serde::{Deserialize, Serialize};
10use tokio::sync::RwLock;
11
12use crate::dev_log;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct APICallRequest {
17 pub extension_id:String,
19
20 pub api_method:String,
22
23 pub arguments:Vec<serde_json::Value>,
25
26 pub correlation_id:Option<String>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct APICallResponse {
33 pub success:bool,
35
36 pub data:Option<serde_json::Value>,
38
39 pub error:Option<String>,
41
42 pub correlation_id:Option<String>,
44}
45
46#[allow(dead_code)]
48pub struct APICall {
49 extension_id:String,
51
52 api_method:String,
54
55 arguments:Vec<serde_json::Value>,
57
58 timestamp:u64,
60}
61
62#[allow(dead_code)]
64type APIMethodHandler = fn(&str, Vec<serde_json::Value>) -> Result<serde_json::Value>;
65
66#[allow(dead_code)]
68type AsyncAPIMethodHandler =
69 fn(&str, Vec<serde_json::Value>) -> Box<dyn std::future::Future<Output = Result<serde_json::Value>> + Send + Unpin>;
70
71#[derive(Clone)]
73pub struct APIMethodInfo {
74 #[allow(dead_code)]
76 name:String,
77
78 #[allow(dead_code)]
80 description:String,
81
82 #[allow(dead_code)]
84 parameters:Option<serde_json::Value>,
85
86 #[allow(dead_code)]
88 returns:Option<serde_json::Value>,
89
90 #[allow(dead_code)]
92 is_async:bool,
93
94 call_count:u64,
96
97 total_time_us:u64,
99}
100
101pub struct APIBridgeImpl {
103 api_methods:Arc<RwLock<HashMap<String, APIMethodInfo>>>,
105
106 stats:Arc<RwLock<APIStats>>,
108
109 contexts:Arc<RwLock<HashMap<String, APIContext>>>,
111}
112
113#[derive(Debug, Clone, Default, Serialize, Deserialize)]
115pub struct APIStats {
116 pub total_calls:u64,
118
119 pub successful_calls:u64,
121
122 pub failed_calls:u64,
124
125 pub avg_latency_us:u64,
127
128 pub active_contexts:usize,
130}
131
132#[derive(Debug, Clone)]
134pub struct APIContext {
135 pub extension_id:String,
137
138 pub context_id:String,
140
141 pub workspace_folder:Option<String>,
143
144 pub active_editor:Option<String>,
146
147 pub selections:Vec<Selection>,
149
150 pub created_at:u64,
152}
153
154#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct Selection {
157 pub start_line:u32,
159
160 pub start_character:u32,
162
163 pub end_line:u32,
165
166 pub end_character:u32,
168}
169
170impl Default for Selection {
171 fn default() -> Self { Self { start_line:0, start_character:0, end_line:0, end_character:0 } }
172}
173
174impl APIBridgeImpl {
175 pub fn new() -> Self {
177 let bridge = Self {
178 api_methods:Arc::new(RwLock::new(HashMap::new())),
179
180 stats:Arc::new(RwLock::new(APIStats::default())),
181
182 contexts:Arc::new(RwLock::new(HashMap::new())),
183 };
184
185 bridge.register_builtin_methods();
186
187 bridge
188 }
189
190 fn register_builtin_methods(&self) {
192 dev_log!("extensions", "Registered built-in VS Code API methods");
201 }
202
203 pub async fn register_method(
205 &self,
206
207 name:&str,
208
209 description:&str,
210
211 parameters:Option<serde_json::Value>,
212
213 returns:Option<serde_json::Value>,
214
215 is_async:bool,
216 ) -> Result<()> {
217 let mut methods = self.api_methods.write().await;
218
219 if methods.contains_key(name) {
220 dev_log!("extensions", "warn: API method already registered: {}", name);
221 }
222
223 methods.insert(
224 name.to_string(),
225 APIMethodInfo {
226 name:name.to_string(),
227 description:description.to_string(),
228 parameters,
229 returns,
230 is_async,
231 call_count:0,
232 total_time_us:0,
233 },
234 );
235
236 dev_log!("extensions", "Registered API method: {}", name);
237
238 Ok(())
239 }
240
241 pub async fn create_context(&self, extension_id:&str) -> Result<APIContext> {
243 let context_id = format!("{}-{}", extension_id, uuid::Uuid::new_v4());
244
245 let context = APIContext {
246 extension_id:extension_id.to_string(),
247
248 context_id:context_id.clone(),
249
250 workspace_folder:None,
251
252 active_editor:None,
253
254 selections:Vec::new(),
255
256 created_at:std::time::SystemTime::now()
257 .duration_since(std::time::UNIX_EPOCH)
258 .map(|d| d.as_secs())
259 .unwrap_or(0),
260 };
261
262 let mut contexts = self.contexts.write().await;
263
264 contexts.insert(context_id.clone(), context.clone());
265
266 let mut stats = self.stats.write().await;
268
269 stats.active_contexts = contexts.len();
270
271 dev_log!("extensions", "Created API context for extension: {}", extension_id);
272
273 Ok(context)
274 }
275
276 pub async fn get_context(&self, context_id:&str) -> Option<APIContext> {
278 self.contexts.read().await.get(context_id).cloned()
279 }
280
281 pub async fn update_context(&self, context:APIContext) -> Result<()> {
283 let mut contexts = self.contexts.write().await;
284
285 contexts.insert(context.context_id.clone(), context);
286
287 Ok(())
288 }
289
290 pub async fn remove_context(&self, context_id:&str) -> Result<bool> {
292 let mut contexts = self.contexts.write().await;
293
294 let removed = contexts.remove(context_id).is_some();
295
296 if removed {
297 let mut stats = self.stats.write().await;
298
299 stats.active_contexts = contexts.len();
300 }
301
302 Ok(removed)
303 }
304
305 pub async fn Call(&self, request:APICallRequest) -> Result<APICallResponse> {
307 let start = std::time::Instant::now();
308
309 dev_log!(
310 "extensions",
311 "Handling API call: {} from extension {}",
312 request.api_method,
313 request.extension_id
314 );
315
316 let exists = {
318 let methods = self.api_methods.read().await;
319
320 methods.contains_key(&request.api_method)
321 };
322
323 if !exists {
324 return Ok(APICallResponse {
325 success:false,
326 data:None,
327 error:Some(format!("API method not found: {}", request.api_method)),
328 correlation_id:request.correlation_id,
329 });
330 }
331
332 let result = self
335 .execute_method(&request.extension_id, &request.api_method, &request.arguments)
336 .await;
337
338 let elapsed_us = start.elapsed().as_micros() as u64;
339
340 let mut stats = self.stats.write().await;
342
343 stats.total_calls += 1;
344
345 stats.total_calls += 1;
346
347 if exists {
348 stats.successful_calls += 1;
349
350 stats.avg_latency_us =
352 (stats.avg_latency_us * (stats.successful_calls - 1) + elapsed_us) / stats.successful_calls;
353 }
354
355 {
357 let mut methods = self.api_methods.write().await;
358
359 if let Some(method) = methods.get_mut(&request.api_method) {
360 method.call_count += 1;
361
362 method.total_time_us += elapsed_us;
363 }
364 }
365
366 dev_log!("extensions", "API call {} completed in {}µs", request.api_method, elapsed_us);
367
368 match result {
369 Ok(data) => {
370 Ok(
371 APICallResponse {
372 success:true,
373 data:Some(data),
374 error:None,
375 correlation_id:request.correlation_id,
376 },
377 )
378 },
379
380 Err(e) => {
381 Ok(APICallResponse {
382 success:false,
383 data:None,
384 error:Some(e.to_string()),
385 correlation_id:request.correlation_id,
386 })
387 },
388 }
389 }
390
391 async fn execute_method(
393 &self,
394
395 _extension_id:&str,
396
397 _method_name:&str,
398
399 _arguments:&[serde_json::Value],
400 ) -> Result<serde_json::Value> {
401 Ok(serde_json::Value::Null)
410 }
411
412 pub async fn stats(&self) -> APIStats { self.stats.read().await.clone() }
414
415 pub async fn get_methods(&self) -> Vec<APIMethodInfo> { self.api_methods.read().await.values().cloned().collect() }
417
418 pub async fn unregister_method(&self, name:&str) -> Result<bool> {
420 let mut methods = self.api_methods.write().await;
421
422 let removed = methods.remove(name).is_some();
423
424 if removed {
425 dev_log!("extensions", "Unregistered API method: {}", name);
426 }
427
428 Ok(removed)
429 }
430}
431
432impl Default for APIBridgeImpl {
433 fn default() -> Self { Self::new() }
434}
435
#[cfg(test)]
mod tests {

	use super::*;

	/// Builds a request for `method` with a fixed extension id and correlation id.
	fn request_for(method:&str) -> APICallRequest {
		APICallRequest {
			extension_id:"test.ext".to_string(),

			api_method:method.to_string(),

			arguments:Vec::new(),

			correlation_id:Some("test-id".to_string()),
		}
	}

	#[tokio::test]
	async fn test_api_bridge_creation() {
		let bridge = APIBridgeImpl::new();

		let stats = bridge.stats().await;

		assert_eq!(stats.total_calls, 0);

		assert_eq!(stats.successful_calls, 0);
	}

	#[tokio::test]
	async fn test_context_creation() {
		let bridge = APIBridgeImpl::new();

		let context = bridge.create_context("test.ext").await.unwrap();

		assert_eq!(context.extension_id, "test.ext");

		assert!(!context.context_id.is_empty());
	}

	#[tokio::test]
	async fn test_context_lifecycle() {
		let bridge = APIBridgeImpl::new();

		let context = bridge.create_context("test.ext").await.unwrap();

		assert_eq!(bridge.stats().await.active_contexts, 1);

		let fetched = bridge.get_context(&context.context_id).await.expect("context was just created");

		assert_eq!(fetched.extension_id, "test.ext");

		assert!(bridge.remove_context(&context.context_id).await.unwrap());

		assert!(!bridge.remove_context(&context.context_id).await.unwrap());

		assert!(bridge.get_context(&context.context_id).await.is_none());

		assert_eq!(bridge.stats().await.active_contexts, 0);
	}

	#[tokio::test]
	async fn test_method_registration() {
		let bridge = APIBridgeImpl::new();

		let result:Result<()> = bridge.register_method("test.method", "Test method", None, None, false).await;

		assert!(result.is_ok());

		let methods:Vec<APIMethodInfo> = bridge.get_methods().await;

		assert!(methods.iter().any(|m| m.name == "test.method"));
	}

	#[tokio::test]
	async fn test_method_unregistration() {
		let bridge = APIBridgeImpl::new();

		bridge.register_method("test.method", "Test method", None, None, false).await.unwrap();

		assert!(bridge.unregister_method("test.method").await.unwrap());

		assert!(!bridge.unregister_method("test.method").await.unwrap());

		assert!(!bridge.get_methods().await.iter().any(|m| m.name == "test.method"));
	}

	#[tokio::test]
	async fn test_call_unknown_method() {
		let bridge = APIBridgeImpl::new();

		let response = bridge.Call(request_for("missing.method")).await.unwrap();

		assert!(!response.success);

		assert!(response.data.is_none());

		assert_eq!(response.error.as_deref(), Some("API method not found: missing.method"));

		assert_eq!(response.correlation_id.as_deref(), Some("test-id"));
	}

	#[tokio::test]
	async fn test_call_registered_method() {
		let bridge = APIBridgeImpl::new();

		bridge.register_method("test.method", "Test method", None, None, false).await.unwrap();

		let response = bridge.Call(request_for("test.method")).await.unwrap();

		assert!(response.success);

		assert_eq!(response.data, Some(serde_json::Value::Null));

		assert!(response.error.is_none());

		assert_eq!(response.correlation_id.as_deref(), Some("test-id"));

		let stats = bridge.stats().await;

		assert_eq!(stats.successful_calls, 1);

		assert_eq!(stats.failed_calls, 0);
	}

	// Pure struct construction; no runtime needed.
	#[test]
	fn test_api_call_request() {
		let request = APICallRequest {
			extension_id:"test.ext".to_string(),

			api_method:"test.method".to_string(),

			arguments:vec![serde_json::json!("arg1")],

			correlation_id:Some("test-id".to_string()),
		};

		assert_eq!(request.extension_id, "test.ext");

		assert_eq!(request.api_method, "test.method");

		assert_eq!(request.arguments.len(), 1);
	}

	#[test]
	fn test_selection_default() {
		let selection = Selection::default();

		assert_eq!(selection.start_line, 0);

		assert_eq!(selection.end_line, 0);
	}
}