1use std::collections::HashMap;
6use std::sync::{Mutex, OnceLock};
7use std::time::{Duration, Instant};
8
9use super::sanitize::{format_injected_message, sanitize_message, validate_agent_id};
10use super::types::*;
11use super::{now_unix_millis, sha256_hex, AUDIT_LOG_MAX, RATE_LIMIT_MAX};
12
13pub(super) struct RateLimiter {
16 tokens: u32,
17 max_tokens: u32,
18 last_refill: Instant,
19}
20
21impl RateLimiter {
22 pub(super) fn new(max_tokens: u32) -> Self {
23 Self {
24 tokens: max_tokens,
25 max_tokens,
26 last_refill: Instant::now(),
27 }
28 }
29
30 pub(super) fn check(&mut self) -> bool {
31 let now = Instant::now();
32 let elapsed = now.duration_since(self.last_refill);
33 if elapsed >= Duration::from_secs(1) {
34 self.tokens = self.max_tokens;
35 self.last_refill = now;
36 }
37 if self.tokens > 0 {
38 self.tokens -= 1;
39 true
40 } else {
41 false
42 }
43 }
44}
45
46pub struct Handler {
53 agent_to_block: HashMap<String, String>,
54 block_to_agent: HashMap<String, String>,
55 agent_info: HashMap<String, AgentRegistration>,
56 input_sender: Option<InputSender>,
57 audit_log: Vec<AuditLogEntry>,
58 rate_limiter: RateLimiter,
59 include_source_in_message: bool,
60}
61
62impl Handler {
63 pub fn new() -> Self {
66 Self {
67 agent_to_block: HashMap::new(),
68 block_to_agent: HashMap::new(),
69 agent_info: HashMap::new(),
70 input_sender: None,
71 audit_log: Vec::with_capacity(AUDIT_LOG_MAX),
72 rate_limiter: RateLimiter::new(RATE_LIMIT_MAX),
73 include_source_in_message: false,
74 }
75 }
76
77 pub fn set_input_sender(&mut self, sender: InputSender) {
79 self.input_sender = Some(sender);
80 }
81
82 #[allow(dead_code)]
84 pub fn set_include_source(&mut self, include: bool) {
85 self.include_source_in_message = include;
86 }
87
88 pub fn register_agent(
90 &mut self,
91 agent_id: &str,
92 block_id: &str,
93 tab_id: Option<&str>,
94 ) -> Result<(), String> {
95 if !validate_agent_id(agent_id) {
96 return Err(format!("invalid agent ID: {}", agent_id));
97 }
98
99 let agent_key = agent_id.to_lowercase();
100
101 if let Some(old_block) = self.agent_to_block.remove(&agent_key) {
103 self.block_to_agent.remove(&old_block);
104 }
105
106 if let Some(old_agent) = self.block_to_agent.remove(block_id) {
108 self.agent_to_block.remove(&old_agent);
109 self.agent_info.remove(&old_agent);
110 }
111
112 let now = now_unix_millis();
113 self.agent_to_block
114 .insert(agent_key.clone(), block_id.to_string());
115 self.block_to_agent
116 .insert(block_id.to_string(), agent_key.clone());
117 self.agent_info.insert(
118 agent_key.clone(),
119 AgentRegistration {
120 agent_id: agent_id.to_string(),
121 block_id: block_id.to_string(),
122 tab_id: tab_id.map(|s| s.to_string()),
123 registered_at: now,
124 last_seen: now,
125 },
126 );
127
128 Ok(())
129 }
130
131 pub fn unregister_agent(&mut self, agent_id: &str) {
133 let agent_key = agent_id.to_lowercase();
134 if let Some(block_id) = self.agent_to_block.remove(&agent_key) {
135 self.block_to_agent.remove(&block_id);
136 }
137 self.agent_info.remove(&agent_key);
138 }
139
140 pub fn unregister_block(&mut self, block_id: &str) {
142 if let Some(agent_id) = self.block_to_agent.remove(block_id) {
143 self.agent_to_block.remove(&agent_id);
144 self.agent_info.remove(&agent_id);
145 }
146 }
147
148 #[allow(dead_code)]
150 pub fn update_last_seen(&mut self, agent_id: &str) {
151 if let Some(info) = self.agent_info.get_mut(&agent_id.to_lowercase()) {
152 info.last_seen = now_unix_millis();
153 }
154 }
155
156 pub fn get_agent(&self, agent_id: &str) -> Option<&AgentRegistration> {
158 self.agent_info.get(&agent_id.to_lowercase())
159 }
160
161 #[allow(dead_code)]
163 pub fn get_agent_by_block(&self, block_id: &str) -> Option<&AgentRegistration> {
164 self.block_to_agent
165 .get(block_id)
166 .and_then(|agent_id| self.agent_info.get(agent_id))
167 }
168
169 pub fn list_agents(&self) -> Vec<AgentRegistration> {
171 self.agent_info.values().cloned().collect()
172 }
173
174 pub fn inject_message(&mut self, mut req: InjectionRequest) -> InjectionResponse {
180 let now = now_unix_millis();
181
182 if req.request_id.is_none() || req.request_id.as_deref() == Some("") {
184 req.request_id = Some(uuid::Uuid::new_v4().to_string());
185 }
186 let request_id = req.request_id.clone().unwrap_or_default();
187
188 if !self.rate_limiter.check() {
190 return InjectionResponse {
191 success: false,
192 request_id,
193 block_id: None,
194 error: Some("rate limit exceeded".to_string()),
195 timestamp: now,
196 };
197 }
198
199 if !validate_agent_id(&req.target_agent) {
201 return InjectionResponse {
202 success: false,
203 request_id,
204 block_id: None,
205 error: Some(format!("invalid agent ID: {}", req.target_agent)),
206 timestamp: now,
207 };
208 }
209
210 let sanitized = sanitize_message(&req.message);
212
213 let block_id = match self.agent_to_block.get(&req.target_agent.to_lowercase()) {
215 Some(id) => id.clone(),
216 None => {
217 let err = format!("agent not found: {}", req.target_agent);
218 self.log_audit(
219 req.source_agent.as_deref(),
220 &req.target_agent,
221 "",
222 &sanitized,
223 false,
224 Some(&err),
225 &request_id,
226 );
227 return InjectionResponse {
228 success: false,
229 request_id,
230 block_id: None,
231 error: Some(err),
232 timestamp: now,
233 };
234 }
235 };
236
237 let final_msg = format_injected_message(
239 &sanitized,
240 req.source_agent.as_deref(),
241 self.include_source_in_message,
242 );
243
244 let sender = match &self.input_sender {
246 Some(s) => s.clone(),
247 None => {
248 let err = "input sender not configured".to_string();
249 self.log_audit(
250 req.source_agent.as_deref(),
251 &req.target_agent,
252 &block_id,
253 &sanitized,
254 false,
255 Some(&err),
256 &request_id,
257 );
258 return InjectionResponse {
259 success: false,
260 request_id,
261 block_id: Some(block_id),
262 error: Some(err),
263 timestamp: now,
264 };
265 }
266 };
267
268 let _ = sender(&block_id, b"\r");
273 let payload = format!("{}\r", final_msg);
274 tracing::info!(
275 target_agent = %req.target_agent,
276 block_id = %block_id,
277 msg_len = payload.len(),
278 "inject: sending payload to PTY"
279 );
280 if let Err(e) = sender(&block_id, payload.as_bytes()) {
281 tracing::error!(
282 target_agent = %req.target_agent,
283 block_id = %block_id,
284 error = %e,
285 "inject: sender failed"
286 );
287 self.log_audit(
288 req.source_agent.as_deref(),
289 &req.target_agent,
290 &block_id,
291 &sanitized,
292 false,
293 Some(&e),
294 &request_id,
295 );
296 return InjectionResponse {
297 success: false,
298 request_id,
299 block_id: Some(block_id),
300 error: Some(e),
301 timestamp: now,
302 };
303 }
304
305 let sender_enter = sender.clone();
307 let block_id_enter = block_id.clone();
308 tokio::spawn(async move {
309 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
310 let _ = sender_enter(&block_id_enter, b"\r");
311 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
312 let _ = sender_enter(&block_id_enter, b"\r");
313 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
314 let _ = sender_enter(&block_id_enter, b"\r");
315 });
316
317 self.log_audit(
319 req.source_agent.as_deref(),
320 &req.target_agent,
321 &block_id,
322 &sanitized,
323 true,
324 None,
325 &request_id,
326 );
327
328 InjectionResponse {
329 success: true,
330 request_id,
331 block_id: Some(block_id),
332 error: None,
333 timestamp: now,
334 }
335 }
336
337 pub fn get_audit_log(&self, limit: usize) -> Vec<AuditLogEntry> {
339 let start = if self.audit_log.len() > limit {
340 self.audit_log.len() - limit
341 } else {
342 0
343 };
344 let mut entries: Vec<_> = self.audit_log[start..].to_vec();
345 entries.reverse();
346 entries
347 }
348
349 #[allow(clippy::too_many_arguments)]
351 pub(super) fn log_audit(
352 &mut self,
353 source_agent: Option<&str>,
354 target_agent: &str,
355 block_id: &str,
356 message: &str,
357 success: bool,
358 error_message: Option<&str>,
359 request_id: &str,
360 ) {
361 let entry = AuditLogEntry {
362 timestamp: now_unix_millis(),
363 source_agent: source_agent.map(|s| s.to_string()),
364 target_agent: target_agent.to_string(),
365 block_id: block_id.to_string(),
366 message_hash: sha256_hex(message),
367 message_length: message.len(),
368 success,
369 error_message: error_message.map(|s| s.to_string()),
370 request_id: request_id.to_string(),
371 };
372
373 if self.audit_log.len() >= AUDIT_LOG_MAX {
374 self.audit_log.remove(0);
375 }
376 self.audit_log.push(entry);
377 }
378}
379
380impl Default for Handler {
381 fn default() -> Self {
382 Self::new()
383 }
384}
385
386pub struct ReactiveHandler {
390 inner: Mutex<Handler>,
391}
392
393impl ReactiveHandler {
394 pub fn new() -> Self {
395 Self {
396 inner: Mutex::new(Handler::new()),
397 }
398 }
399
400 pub fn set_input_sender(&self, sender: InputSender) {
401 self.inner.lock().unwrap().set_input_sender(sender);
402 }
403
404 #[allow(dead_code)]
405 pub fn set_include_source(&self, include: bool) {
406 self.inner.lock().unwrap().set_include_source(include);
407 }
408
409 pub fn register_agent(
410 &self,
411 agent_id: &str,
412 block_id: &str,
413 tab_id: Option<&str>,
414 ) -> Result<(), String> {
415 self.inner
416 .lock()
417 .unwrap()
418 .register_agent(agent_id, block_id, tab_id)
419 }
420
421 pub fn unregister_agent(&self, agent_id: &str) {
422 self.inner.lock().unwrap().unregister_agent(agent_id);
423 }
424
425 pub fn unregister_block(&self, block_id: &str) {
426 self.inner.lock().unwrap().unregister_block(block_id);
427 }
428
429 #[allow(dead_code)]
430 pub fn update_last_seen(&self, agent_id: &str) {
431 self.inner.lock().unwrap().update_last_seen(agent_id);
432 }
433
434 pub fn get_agent(&self, agent_id: &str) -> Option<AgentRegistration> {
435 self.inner.lock().unwrap().get_agent(agent_id).cloned()
436 }
437
438 #[allow(dead_code)]
439 pub fn get_agent_by_block(&self, block_id: &str) -> Option<AgentRegistration> {
440 self.inner
441 .lock()
442 .unwrap()
443 .get_agent_by_block(block_id)
444 .cloned()
445 }
446
447 pub fn list_agents(&self) -> Vec<AgentRegistration> {
448 self.inner.lock().unwrap().list_agents()
449 }
450
451 pub fn inject_message(&self, req: InjectionRequest) -> InjectionResponse {
452 self.inner.lock().unwrap().inject_message(req)
453 }
454
455 pub fn get_audit_log(&self, limit: usize) -> Vec<AuditLogEntry> {
456 self.inner.lock().unwrap().get_audit_log(limit)
457 }
458}
459
460impl Default for ReactiveHandler {
461 fn default() -> Self {
462 Self::new()
463 }
464}
465
466static GLOBAL_HANDLER: OnceLock<ReactiveHandler> = OnceLock::new();
468
469pub fn get_global_handler() -> &'static ReactiveHandler {
471 GLOBAL_HANDLER.get_or_init(ReactiveHandler::new)
472}