diff --git a/src/agent/mod.rs b/src/agent/mod.rs new file mode 100644 index 0000000..abfb265 --- /dev/null +++ b/src/agent/mod.rs @@ -0,0 +1,63 @@ +pub mod planner; + +use crate::pipe::{AgentMessage, BrowserMessage, BrowserPipeTool, PipeError, Transport}; + +pub fn execute_task( + transport: &T, + browser_tool: &BrowserPipeTool, + instruction: &str, +) -> Result { + let plan = planner::plan_instruction(instruction) + .map_err(|err| PipeError::Protocol(err.to_string()))?; + + for step in &plan.steps { + transport.send(&AgentMessage::LogEntry { + level: "info".to_string(), + message: step.log_message.clone(), + })?; + + let result = browser_tool.invoke( + step.action.clone(), + step.params.clone(), + &step.expected_domain, + )?; + if !result.success { + return Err(PipeError::Protocol(format!( + "browser action failed: {}", + result.data + ))); + } + } + + Ok(plan.summary) +} + +pub fn handle_browser_message( + transport: &T, + browser_tool: &BrowserPipeTool, + message: BrowserMessage, +) -> Result<(), PipeError> { + match message { + BrowserMessage::SubmitTask { instruction } => { + let completion = match execute_task(transport, browser_tool, &instruction) { + Ok(summary) => AgentMessage::TaskComplete { + success: true, + summary, + }, + Err(err) => AgentMessage::TaskComplete { + success: false, + summary: err.to_string(), + }, + }; + transport.send(&completion) + } + BrowserMessage::Init { .. } => { + eprintln!("ignoring duplicate init after handshake"); + Ok(()) + } + BrowserMessage::Response { seq, .. } => { + eprintln!("ignoring unsolicited response: seq={seq}"); + Ok(()) + } + } +} diff --git a/src/agent/planner.rs b/src/agent/planner.rs new file mode 100644 index 0000000..9337001 --- /dev/null +++ b/src/agent/planner.rs @@ -0,0 +1,72 @@ +use serde_json::{json, Value}; +use thiserror::Error; + +use crate::pipe::Action; + +const BAIDU_URL: &str = "https://www.baidu.com"; +const BAIDU_DOMAIN: &str = "www.baidu.com"; +const BAIDU_INPUT_SELECTOR: &str = "#kw"; +const BAIDU_SEARCH_BUTTON_SELECTOR: &str = "#su"; + +#[derive(Debug, Clone, PartialEq)] +pub struct PlannedStep { + pub action: Action, + pub params: Value, + pub expected_domain: String, + pub log_message: String, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct TaskPlan { + pub summary: String, + pub steps: Vec, +} + +#[derive(Debug, Error, Clone, PartialEq, Eq)] +pub enum PlannerError { + #[error("unsupported instruction: {0}")] + UnsupportedInstruction(String), + #[error("missing search query in instruction")] + MissingQuery, +} + +pub fn plan_instruction(instruction: &str) -> Result { + let trimmed = instruction.trim(); + let query = trimmed + .strip_prefix("打开百度搜索") + .or_else(|| trimmed.strip_prefix("打开百度并搜索")) + .ok_or_else(|| PlannerError::UnsupportedInstruction(trimmed.to_string()))? + .trim(); + + if query.is_empty() { + return Err(PlannerError::MissingQuery); + } + + Ok(TaskPlan { + summary: format!("已在百度搜索{query}"), + steps: vec![ + PlannedStep { + action: Action::Navigate, + params: json!({ "url": BAIDU_URL }), + expected_domain: BAIDU_DOMAIN.to_string(), + log_message: "navigate https://www.baidu.com".to_string(), + }, + PlannedStep { + action: Action::Type, + params: json!({ + "selector": BAIDU_INPUT_SELECTOR, + "text": query, + "clear_first": true + }), + expected_domain: BAIDU_DOMAIN.to_string(), + log_message: format!("type {query} into {BAIDU_INPUT_SELECTOR}"), + }, + PlannedStep { + action: Action::Click, + params: json!({ "selector": BAIDU_SEARCH_BUTTON_SELECTOR }), + expected_domain: BAIDU_DOMAIN.to_string(), + log_message: format!("click {BAIDU_SEARCH_BUTTON_SELECTOR}"), + }, + ], + }) +} diff --git a/src/lib.rs b/src/lib.rs index 71d2aab..074cf0e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod agent; pub mod pipe; pub mod security; @@ -5,6 +6,7 @@ use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; +use agent::handle_browser_message; use pipe::{perform_handshake, BrowserPipeTool, PipeError, StdioTransport, Transport}; use security::MacPolicy; @@ -19,7 +21,7 @@ pub fn run() -> Result<(), PipeError> { let transport = Arc::new(StdioTransport::new(std::io::stdin(), std::io::stdout())); let handshake = perform_handshake(transport.as_ref(), Duration::from_secs(5))?; let mac_policy = MacPolicy::load_from_path(default_rules_path())?; - let _browser_tool = BrowserPipeTool::new(transport.clone(), mac_policy, handshake.session_key) + let browser_tool = BrowserPipeTool::new(transport.clone(), mac_policy, handshake.session_key) .with_response_timeout(Duration::from_secs(30)); eprintln!("sgclaw ready: agent_id={}", handshake.agent_id); @@ -27,7 +29,7 @@ pub fn run() -> Result<(), PipeError> { loop { match transport.recv_timeout(Duration::from_secs(3600)) { Ok(message) => { - eprintln!("ignoring unsolicited browser message: {:?}", message); + handle_browser_message(transport.as_ref(), &browser_tool, message)?; } Err(PipeError::Timeout) => continue, Err(PipeError::PipeClosed) => return Ok(()), diff --git a/tests/planner_test.rs b/tests/planner_test.rs new file mode 100644 index 0000000..9dc2f69 --- /dev/null +++ b/tests/planner_test.rs @@ -0,0 +1,41 @@ +use serde_json::json; +use sgclaw::agent::planner::{plan_instruction, PlannerError}; +use sgclaw::pipe::Action; + +#[test] +fn planner_converts_baidu_search_instruction_into_three_steps() { + let plan = plan_instruction("打开百度搜索天气").unwrap(); + + assert_eq!(plan.summary, "已在百度搜索天气"); + assert_eq!(plan.steps.len(), 3); + assert_eq!(plan.steps[0].action, Action::Navigate); + assert_eq!( + plan.steps[0].params, + json!({ "url": "https://www.baidu.com" }) + ); + assert_eq!(plan.steps[1].action, Action::Type); + assert_eq!( + plan.steps[1].params, + json!({ "selector": "#kw", "text": "天气", "clear_first": true }) + ); + assert_eq!(plan.steps[2].action, Action::Click); + assert_eq!(plan.steps[2].params, json!({ "selector": "#su" })); +} + +#[test] +fn planner_supports_baidu_search_variant_with_conjunction() { + let plan = plan_instruction("打开百度并搜索电网调度").unwrap(); + + assert_eq!(plan.summary, "已在百度搜索电网调度"); + assert_eq!(plan.steps[1].params["text"], "电网调度"); +} + +#[test] +fn planner_rejects_unrelated_instruction() { + let err = plan_instruction("打开谷歌搜索天气").unwrap_err(); + + assert_eq!( + err, + PlannerError::UnsupportedInstruction("打开谷歌搜索天气".to_string()) + ); +} diff --git a/tests/runtime_task_flow_test.rs b/tests/runtime_task_flow_test.rs new file mode 100644 index 0000000..af81bbc --- /dev/null +++ b/tests/runtime_task_flow_test.rs @@ -0,0 +1,113 @@ +mod common; + +use std::sync::Arc; +use std::time::Duration; + +use common::MockTransport; +use sgclaw::agent::handle_browser_message; +use sgclaw::pipe::{Action, AgentMessage, BrowserMessage, BrowserPipeTool, Timing}; +use sgclaw::security::MacPolicy; + +fn test_policy() -> MacPolicy { + MacPolicy::from_json_str( + r#"{ + "version": "1.0", + "domains": { "allowed": ["oa.example.com", "www.baidu.com"] }, + "pipe_actions": { + "allowed": ["click", "type", "navigate", "getText"], + "blocked": ["eval", "executeJsInPage"] + } + }"#, + ) + .unwrap() +} + +#[test] +fn submit_task_sends_three_commands_and_finishes_with_task_complete() { + let transport = Arc::new(MockTransport::new(vec![ + BrowserMessage::Response { + seq: 1, + success: true, + data: serde_json::json!({ "navigated": true }), + aom_snapshot: vec![], + timing: Timing { + queue_ms: 1, + exec_ms: 20, + }, + }, + BrowserMessage::Response { + seq: 2, + success: true, + data: serde_json::json!({ "typed": true }), + aom_snapshot: vec![], + timing: Timing { + queue_ms: 1, + exec_ms: 20, + }, + }, + BrowserMessage::Response { + seq: 3, + success: true, + data: serde_json::json!({ "clicked": true }), + aom_snapshot: vec![], + timing: Timing { + queue_ms: 1, + exec_ms: 20, + }, + }, + ])); + let tool = BrowserPipeTool::new( + transport.clone(), + test_policy(), + vec![1, 2, 3, 4, 5, 6, 7, 8], + ) + .with_response_timeout(Duration::from_secs(1)); + + handle_browser_message( + transport.as_ref(), + &tool, + BrowserMessage::SubmitTask { + instruction: "打开百度搜索天气".to_string(), + }, + ) + .unwrap(); + + let sent = transport.sent_messages(); + + assert_eq!(sent.len(), 7); + assert!(matches!( + &sent[0], + AgentMessage::LogEntry { level, message } + if level == "info" && message == "navigate https://www.baidu.com" + )); + assert!(matches!( + &sent[1], + AgentMessage::Command { seq, action, .. } + if *seq == 1 && action == &Action::Navigate + )); + assert!(matches!( + &sent[2], + AgentMessage::LogEntry { level, message } + if level == "info" && message == "type 天气 into #kw" + )); + assert!(matches!( + &sent[3], + AgentMessage::Command { seq, action, .. } + if *seq == 2 && action == &Action::Type + )); + assert!(matches!( + &sent[4], + AgentMessage::LogEntry { level, message } + if level == "info" && message == "click #su" + )); + assert!(matches!( + &sent[5], + AgentMessage::Command { seq, action, .. } + if *seq == 3 && action == &Action::Click + )); + assert!(matches!( + &sent[6], + AgentMessage::TaskComplete { success, summary } + if *success && summary == "已在百度搜索天气" + )); +}