agentmux_srv\backend\reactive/
sanitize.rs

1// Copyright 2025-2026, AgentMux Corp.
2// SPDX-License-Identifier: Apache-2.0
3
4
5use super::{MAX_MESSAGE_LENGTH, TRUNCATION_SUFFIX};
6
7/// Sanitize a message by removing dangerous escape sequences and control characters.
8///
9/// 1. Removes ANSI escape sequences
10/// 2. Removes OSC sequences (terminal commands)
11/// 3. Removes CSI sequences
12/// 4. Removes control characters except \n, \t, \r
13/// 5. Truncates to MAX_MESSAGE_LENGTH with UTF-8 safety
14pub fn sanitize_message(msg: &str) -> String {
15    let mut result = String::with_capacity(msg.len());
16
17    let bytes = msg.as_bytes();
18    let len = bytes.len();
19    let mut i = 0;
20
21    while i < len {
22        let b = bytes[i];
23
24        // Check for ESC sequences
25        if b == 0x1b && i + 1 < len {
26            let next = bytes[i + 1];
27
28            // CSI sequence: ESC [ ... <final byte>
29            if next == b'[' {
30                i += 2;
31                while i < len && !(bytes[i] >= 0x40 && bytes[i] <= 0x7e) {
32                    i += 1;
33                }
34                if i < len {
35                    i += 1; // skip final byte
36                }
37                continue;
38            }
39
40            // OSC sequence: ESC ] ... BEL
41            if next == b']' {
42                i += 2;
43                while i < len && bytes[i] != 0x07 {
44                    // Also check for ST (ESC \)
45                    if bytes[i] == 0x1b && i + 1 < len && bytes[i + 1] == b'\\' {
46                        i += 2;
47                        break;
48                    }
49                    i += 1;
50                }
51                if i < len && bytes[i] == 0x07 {
52                    i += 1;
53                }
54                continue;
55            }
56
57            // Other ESC sequences (2-byte)
58            i += 2;
59            continue;
60        }
61
62        // Remove control characters except whitespace
63        if b < 0x20 && b != b'\n' && b != b'\r' && b != b'\t' {
64            i += 1;
65            continue;
66        }
67
68        // DEL character
69        if b == 0x7f {
70            i += 1;
71            continue;
72        }
73
74        // Keep printable characters and valid UTF-8
75        if b < 0x80 {
76            result.push(b as char);
77            i += 1;
78        } else {
79            // UTF-8 multi-byte: determine sequence length
80            let seq_len = if b >= 0xF0 {
81                4
82            } else if b >= 0xE0 {
83                3
84            } else if b >= 0xC0 {
85                2
86            } else {
87                // Invalid continuation byte, skip
88                i += 1;
89                continue;
90            };
91
92            if i + seq_len <= len {
93                let s = std::str::from_utf8(&bytes[i..i + seq_len]);
94                if let Ok(valid) = s {
95                    result.push_str(valid);
96                }
97                i += seq_len;
98            } else {
99                // Incomplete sequence
100                i += 1;
101            }
102        }
103    }
104
105    // Truncate to max length, preserving UTF-8
106    if result.len() > MAX_MESSAGE_LENGTH {
107        let suffix_len = TRUNCATION_SUFFIX.len();
108        let target = MAX_MESSAGE_LENGTH - suffix_len;
109        // Find a valid UTF-8 boundary
110        let mut end = target;
111        while end > 0 && !result.is_char_boundary(end) {
112            end -= 1;
113        }
114        result.truncate(end);
115        result.push_str(TRUNCATION_SUFFIX);
116    }
117
118    result
119}
120
/// Check whether `agent_id` is an acceptable agent identifier.
///
/// Returns `true` only when the ID is between 1 and 64 bytes long and every
/// character is an ASCII letter, an ASCII digit, `_` or `-`.
pub fn validate_agent_id(agent_id: &str) -> bool {
    const MAX_LEN: usize = 64;

    let length_ok = (1..=MAX_LEN).contains(&agent_id.len());
    length_ok
        && agent_id
            .chars()
            .all(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-'))
}
132
/// Build the text to inject, optionally prefixed with the sending agent.
///
/// When `include_source` is set and `source_agent` holds a non-empty name,
/// the result is `@<source>: <msg>`; in every other case `msg` is returned
/// unchanged as an owned `String`.
pub fn format_injected_message(msg: &str, source_agent: Option<&str>, include_source: bool) -> String {
    // Keep the source only if the caller asked for it and it is non-empty.
    let prefix_source = source_agent.filter(|name| include_source && !name.is_empty());

    match prefix_source {
        Some(name) => format!("@{}: {}", name, msg),
        None => msg.to_owned(),
    }
}
144
/// Validate an AgentMux URL for SSRF protection.
///
/// Accepted URLs:
/// - any `https://` URL;
/// - `http://` URLs whose host is exactly `localhost`, `127.0.0.1` or `[::1]`,
///   optionally followed by a numeric `:port`.
///
/// For `http` the authority is parsed strictly so that the host that is
/// checked is the host a client would actually connect to:
/// - the authority ends at the first `/`, `?` or `#`;
/// - a userinfo component (`user:pass@`) is rejected outright, otherwise
///   `http://localhost:80@evil.example/` would pass the check;
/// - an IPv6 literal must be closed with `]`, and anything after the host
///   must be empty or `:` followed by digits.
///
/// # Errors
///
/// Returns a human-readable message when the URL is empty, has no `://`
/// separator, uses a scheme other than `http`/`https`, or is an `http` URL
/// that does not point at the local host.
#[allow(dead_code)]
pub fn validate_agentmux_url(url_str: &str) -> Result<(), String> {
    if url_str.is_empty() {
        return Err("URL is empty".to_string());
    }

    let (scheme, rest) = match url_str.split_once("://") {
        Some(parts) => parts,
        None => return Err("invalid URL: missing scheme".to_string()),
    };

    match scheme {
        "https" => Ok(()),
        "http" => {
            // The authority stops at the start of the path, query or fragment.
            let authority = rest
                .split(|c: char| matches!(c, '/' | '?' | '#'))
                .next()
                .unwrap_or("");

            // Userinfo would make the text before the first ':' something
            // other than the real host, so refuse it instead of guessing.
            if authority.contains('@') {
                return Err(format!(
                    "http URLs must not contain userinfo, got authority: {}",
                    authority
                ));
            }

            // Split into host and whatever trails it (expected: "" or ":port").
            let (host, tail) = match authority.strip_prefix('[') {
                // Bracketed IPv6 literal: [::1]:port
                Some(inner) => match inner.split_once(']') {
                    Some(parts) => parts,
                    None => {
                        return Err(format!(
                            "invalid URL: unterminated IPv6 literal: {}",
                            authority
                        ))
                    }
                },
                None => match authority.find(':') {
                    Some(idx) => authority.split_at(idx),
                    None => (authority, ""),
                },
            };

            let port_ok = match tail.strip_prefix(':') {
                Some(port) => port.bytes().all(|b| b.is_ascii_digit()),
                None => tail.is_empty(),
            };
            if !port_ok {
                return Err(format!(
                    "invalid URL: malformed port in authority: {}",
                    authority
                ));
            }

            match host {
                "localhost" | "127.0.0.1" | "::1" => Ok(()),
                _ => Err(format!(
                    "http URLs only allowed for localhost, got host: {}",
                    host
                )),
            }
        }
        _ => Err(format!("unsupported URL scheme: {}", scheme)),
    }
}
186}