1use std::path::PathBuf;
22use std::process::Stdio;
23
24use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
25use tokio::process::Command;
26use tokio::sync::{mpsc, oneshot};
27
28use super::translator::claude::ClaudeTranslator;
29use super::translator::Translator as _;
30use super::types::{AgentEvent, AgentRef, AgentRunResult, AgentTask};
31
/// Executable spawned when `ENV_CLAUDE_BIN` does not supply one. It is a bare name,
/// so the OS resolves it through `PATH`.
const DEFAULT_CLAUDE_BIN: &str = "claude";

/// Environment variable that `run_agent` consults to choose which Claude CLI
/// executable to spawn (useful for wrappers or pinned installs).
const ENV_CLAUDE_BIN: &str = "AGENTMUX_CLAUDE_BIN";
37
38pub struct AgentRunHandle {
49 pub instance_id: String,
50 pub final_result: oneshot::Receiver<Result<AgentRunResult, String>>,
51}
52
53#[derive(Debug, thiserror::Error)]
55pub enum AgentError {
56 #[error("agent runner: invalid AgentRef: {0}")]
57 InvalidRef(String),
58 #[error("agent runner: spawn failed: {0}")]
59 Spawn(String),
60}
61
62pub async fn run_agent(
78 agent_ref: AgentRef,
79 task: AgentTask,
80 tx: mpsc::UnboundedSender<AgentEvent>,
81) -> Result<AgentRunHandle, AgentError> {
82 let bin = std::env::var(ENV_CLAUDE_BIN)
83 .unwrap_or_else(|_| DEFAULT_CLAUDE_BIN.to_string());
84 run_agent_with_bin(&bin, agent_ref, task, tx).await
85}
86
87pub(crate) async fn run_agent_with_bin(
94 bin: &str,
95 agent_ref: AgentRef,
96 task: AgentTask,
97 tx: mpsc::UnboundedSender<AgentEvent>,
98) -> Result<AgentRunHandle, AgentError> {
99 let working_dir = if agent_ref.working_directory.is_empty() {
100 std::env::current_dir()
101 .map_err(|e| AgentError::Spawn(format!("cwd: {e}")))?
102 } else {
103 PathBuf::from(&agent_ref.working_directory)
104 };
105
106 let mut cmd = Command::new(bin);
114 cmd.arg("--print")
115 .arg("--output-format=stream-json")
116 .arg("--verbose")
117 .arg("--include-partial-messages");
118 if let Some(n) = task.max_turns {
123 cmd.arg("--max-turns").arg(n.to_string());
124 }
125 let mut child = cmd
126 .arg(&task.prompt)
127 .current_dir(&working_dir)
128 .stdin(Stdio::null())
129 .stdout(Stdio::piped())
130 .stderr(Stdio::piped())
131 .kill_on_drop(true)
132 .spawn()
133 .map_err(|e| AgentError::Spawn(format!("spawn `{bin}`: {e}")))?;
134
135 let stdout = child
136 .stdout
137 .take()
138 .ok_or_else(|| AgentError::Spawn("claude stdout pipe missing".to_string()))?;
139 let stderr = child
140 .stderr
141 .take()
142 .ok_or_else(|| AgentError::Spawn("claude stderr pipe missing".to_string()))?;
143
144 let instance_id = format!("drone-agent-{}", uuid::Uuid::new_v4());
145 let (result_tx, result_rx) = oneshot::channel();
146
147 tokio::spawn(async move {
156 const STDERR_CAP: usize = 64 * 1024;
157 let mut buf = Vec::with_capacity(4096);
158 let mut reader = BufReader::new(stderr);
159 let mut sink = [0u8; 4096];
160 loop {
162 if buf.len() < STDERR_CAP {
163 let space = STDERR_CAP - buf.len();
164 let take = space.min(sink.len());
165 match reader.read(&mut sink[..take]).await {
166 Ok(0) => break,
167 Ok(n) => buf.extend_from_slice(&sink[..n]),
168 Err(_) => break,
169 }
170 } else {
171 match reader.read(&mut sink).await {
173 Ok(0) => break,
174 Ok(_) => {}
175 Err(_) => break,
176 }
177 }
178 }
179 let _ = buf;
181 });
182
183 tokio::spawn(async move {
184 let result = drain_and_collect(stdout, &tx, &mut child).await;
185 let _ = result_tx.send(result);
186 });
187
188 Ok(AgentRunHandle {
189 instance_id,
190 final_result: result_rx,
191 })
192}
193
194async fn drain_and_collect(
203 stdout: tokio::process::ChildStdout,
204 tx: &mpsc::UnboundedSender<AgentEvent>,
205 child: &mut tokio::process::Child,
206) -> Result<AgentRunResult, String> {
207 let result = drain_async_reader(BufReader::new(stdout), tx).await;
208
209 let exit = child.wait().await.map_err(|e| format!("wait: {e}"))?;
211
212 match result {
213 Ok(mut accumulated) if exit.success() => {
214 if accumulated.response.is_empty() && accumulated.transcript.is_empty() {
218 return Err("claude exited 0 but stream produced no Done event".to_string());
219 }
220 accumulated.transcript.shrink_to_fit();
221 Ok(accumulated)
222 }
223 Ok(_) => Err(format!(
224 "claude exited with status {exit} but stream emitted no error"
225 )),
226 Err(e) => Err(e),
227 }
228}
229
230pub(crate) async fn drain_async_reader<R: tokio::io::AsyncBufRead + Unpin>(
237 mut reader: R,
238 tx: &mpsc::UnboundedSender<AgentEvent>,
239) -> Result<AgentRunResult, String> {
240 let mut translator = ClaudeTranslator::new();
241 let mut accumulated = AgentRunResult::default();
242 let mut line = String::new();
243 loop {
244 line.clear();
245 let n = reader
246 .read_line(&mut line)
247 .await
248 .map_err(|e| format!("stdout read: {e}"))?;
249 if n == 0 {
250 break; }
252 let trimmed = line.trim_end_matches(['\n', '\r']);
253 if !trimmed.starts_with('{') {
254 continue;
255 }
256 let Ok(frame) = serde_json::from_str::<serde_json::Value>(trimmed) else {
257 continue;
258 };
259 for event in translator.translate(frame) {
260 match &event {
263 AgentEvent::Cost { cost_usd, tokens } => {
264 accumulated.cost_usd = *cost_usd;
265 accumulated.tokens = tokens.clone();
266 }
267 AgentEvent::Done {
268 response,
269 transcript,
270 } => {
271 accumulated.response = response.clone();
272 accumulated.transcript = transcript.clone();
273 }
274 _ => {}
275 }
276 let _ = tx.send(event);
280 }
281 }
282 Ok(accumulated)
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tokio::io::AsyncWriteExt;

    /// Builds a synthetic `stream-json` transcript for `prompt_reply`, one frame per line:
    /// - one `content_block_delta` / `text_delta` stream event per character;
    /// - an `assistant` message carrying the full text;
    /// - a `result` frame reporting `cost` and fixed usage (10 input / 5 output tokens).
    ///
    /// Frames are built as `serde_json::Value`s rather than by string interpolation, so
    /// a reply containing `"`, `\` or control characters still yields valid JSON.
    fn synthetic_stream(prompt_reply: &str, cost: f64) -> Vec<u8> {
        let deltas = prompt_reply.chars().map(|ch| {
            json!({
                "type": "stream_event",
                "event": {
                    "type": "content_block_delta",
                    "delta": { "type": "text_delta", "text": ch.to_string() }
                }
            })
        });
        let assistant = json!({
            "type": "assistant",
            "message": { "content": [{ "type": "text", "text": prompt_reply }] }
        });
        let result = json!({
            "type": "result",
            "cost_usd": cost,
            "usage": {
                "input_tokens": 10,
                "output_tokens": 5,
                "cache_creation_input_tokens": 0,
                "cache_read_input_tokens": 0
            },
            "result": prompt_reply
        });

        let mut out = String::new();
        for frame in deltas.chain([assistant, result]) {
            out.push_str(&frame.to_string());
            out.push('\n');
        }
        out.into_bytes()
    }

    /// Returns the read half of an in-memory pipe.
    ///
    /// A spawned task writes `bytes` into the pipe in writes of at most `chunk` bytes,
    /// then shuts the writer down so the reader sees EOF.
    fn feed(bytes: Vec<u8>, chunk: usize) -> tokio::io::DuplexStream {
        let (mut writer, reader) = tokio::io::duplex(4096);
        tokio::spawn(async move {
            for piece in bytes.chunks(chunk.max(1)) {
                writer.write_all(piece).await.unwrap();
            }
            writer.shutdown().await.unwrap();
        });
        reader
    }

    /// Like `feed`, but writes the whole buffer in a single write.
    fn feed_all(bytes: Vec<u8>) -> tokio::io::DuplexStream {
        feed(bytes, usize::MAX)
    }

    #[tokio::test]
    async fn drain_async_reader_accumulates_cost_and_done() {
        let reader = feed_all(synthetic_stream("hello", 0.001));
        let (tx, mut rx) = mpsc::unbounded_channel();
        let result = drain_async_reader(BufReader::new(reader), &tx)
            .await
            .expect("drain ok");

        assert_eq!(result.response, "hello");
        assert_eq!(result.cost_usd, 0.001);
        assert_eq!(result.tokens.input, 10);
        assert_eq!(result.tokens.output, 5);
        assert_eq!(result.transcript.len(), 1);

        // Dropping the only sender lets `recv` return `None` once the queue is empty.
        drop(tx);
        let mut evs = Vec::new();
        while let Some(e) = rx.recv().await {
            evs.push(e);
        }
        assert_eq!(evs.len(), 7, "got events: {evs:?}");
        assert!(
            matches!(evs.last(), Some(AgentEvent::Done { .. })),
            "expected last event Done, got {:?}",
            evs.last()
        );
    }

    #[tokio::test]
    async fn drain_async_reader_skips_non_json_lines() {
        let mut bytes: Vec<u8> = b"Reading config...\n".to_vec();
        bytes.extend_from_slice(&synthetic_stream("ok", 0.0));
        bytes.extend_from_slice(b"\n");

        let (tx, _rx) = mpsc::unbounded_channel();
        let result = drain_async_reader(BufReader::new(feed_all(bytes)), &tx)
            .await
            .expect("drain ok");
        assert_eq!(result.response, "ok");
    }

    #[tokio::test]
    async fn drain_async_reader_returns_empty_on_no_stream() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let result = drain_async_reader(BufReader::new(feed_all(Vec::new())), &tx)
            .await
            .expect("drain ok");
        assert!(result.response.is_empty());
        assert_eq!(result.cost_usd, 0.0);
    }

    #[tokio::test]
    async fn drain_async_reader_handles_multi_line_chunks() {
        // Small writes split frames across reads, exercising line reassembly.
        let reader = feed(synthetic_stream("multi", 0.01), 7);
        let (tx, _rx) = mpsc::unbounded_channel();
        let result = drain_async_reader(BufReader::new(reader), &tx)
            .await
            .expect("drain ok");
        assert_eq!(result.response, "multi");
    }

    #[tokio::test]
    async fn drain_async_reader_round_trips_replies_needing_json_escapes() {
        let reply = "say \"hi\" \\ bye";
        let (tx, _rx) = mpsc::unbounded_channel();
        let result = drain_async_reader(BufReader::new(feed_all(synthetic_stream(reply, 0.0))), &tx)
            .await
            .expect("drain ok");
        assert_eq!(result.response, reply);
    }

    #[tokio::test]
    async fn drain_handles_malformed_json_gracefully() {
        let mut bytes: Vec<u8> = b"{this is not valid json\n{\"type\":\"unknown\"}\n".to_vec();
        bytes.extend_from_slice(&synthetic_stream("recovered", 0.0));

        let (tx, _rx) = mpsc::unbounded_channel();
        let result = drain_async_reader(BufReader::new(feed_all(bytes)), &tx)
            .await
            .expect("drain ok");
        assert_eq!(result.response, "recovered");
    }

    #[tokio::test]
    #[ignore = "requires `claude` CLI on PATH; run manually for end-to-end"]
    async fn run_agent_end_to_end_with_real_claude() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let handle = run_agent(
            AgentRef::default(),
            AgentTask {
                prompt: "What is 2+2? Respond with just the number.".to_string(),
                context: serde_json::Map::new(),
                max_turns: None,
            },
            tx,
        )
        .await
        .expect("spawn ok");

        // Drain events until the runner drops its sender.
        while let Some(_ev) = rx.recv().await {}

        let result = handle
            .final_result
            .await
            .expect("oneshot ok")
            .expect("agent run ok");
        assert!(result.response.contains('4'), "got: {}", result.response);
        assert!(result.cost_usd > 0.0);
    }

    #[tokio::test]
    async fn run_agent_with_bin_surfaces_spawn_failure() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let result = run_agent_with_bin(
            "/definitely/does/not/exist/claude-xyz-test",
            AgentRef::default(),
            AgentTask {
                prompt: "hi".to_string(),
                context: serde_json::Map::new(),
                max_turns: None,
            },
            tx,
        )
        .await;
        match result {
            Err(AgentError::Spawn(msg)) => {
                assert!(
                    msg.contains("spawn") || msg.contains("does/not/exist"),
                    "spawn error message should reference the failure; got: {msg}"
                );
            }
            Err(other) => panic!("expected Spawn error, got: {other}"),
            Ok(_) => panic!("expected Spawn error, got Ok(handle)"),
        }
    }

    #[test]
    fn agent_task_max_turns_field_round_trips() {
        let task = AgentTask {
            prompt: "x".into(),
            context: serde_json::Map::new(),
            max_turns: Some(7),
        };
        let v = serde_json::to_value(&task).unwrap();
        assert_eq!(v["maxTurns"], json!(7));
        let back: AgentTask = serde_json::from_value(v).unwrap();
        assert_eq!(back.max_turns, Some(7));
    }

    #[test]
    fn agent_run_result_has_sensible_defaults() {
        let r = AgentRunResult::default();
        assert_eq!(r.response, "");
        assert_eq!(r.cost_usd, 0.0);
        assert!(r.transcript.is_empty());
        let _ = json!(r);
    }
}