feat: add phase1 task planner flow
This commit is contained in:
63
src/agent/mod.rs
Normal file
63
src/agent/mod.rs
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
pub mod planner;
|
||||||
|
|
||||||
|
use crate::pipe::{AgentMessage, BrowserMessage, BrowserPipeTool, PipeError, Transport};
|
||||||
|
|
||||||
|
pub fn execute_task<T: Transport>(
|
||||||
|
transport: &T,
|
||||||
|
browser_tool: &BrowserPipeTool<T>,
|
||||||
|
instruction: &str,
|
||||||
|
) -> Result<String, PipeError> {
|
||||||
|
let plan = planner::plan_instruction(instruction)
|
||||||
|
.map_err(|err| PipeError::Protocol(err.to_string()))?;
|
||||||
|
|
||||||
|
for step in &plan.steps {
|
||||||
|
transport.send(&AgentMessage::LogEntry {
|
||||||
|
level: "info".to_string(),
|
||||||
|
message: step.log_message.clone(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let result = browser_tool.invoke(
|
||||||
|
step.action.clone(),
|
||||||
|
step.params.clone(),
|
||||||
|
&step.expected_domain,
|
||||||
|
)?;
|
||||||
|
if !result.success {
|
||||||
|
return Err(PipeError::Protocol(format!(
|
||||||
|
"browser action failed: {}",
|
||||||
|
result.data
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(plan.summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handle_browser_message<T: Transport>(
|
||||||
|
transport: &T,
|
||||||
|
browser_tool: &BrowserPipeTool<T>,
|
||||||
|
message: BrowserMessage,
|
||||||
|
) -> Result<(), PipeError> {
|
||||||
|
match message {
|
||||||
|
BrowserMessage::SubmitTask { instruction } => {
|
||||||
|
let completion = match execute_task(transport, browser_tool, &instruction) {
|
||||||
|
Ok(summary) => AgentMessage::TaskComplete {
|
||||||
|
success: true,
|
||||||
|
summary,
|
||||||
|
},
|
||||||
|
Err(err) => AgentMessage::TaskComplete {
|
||||||
|
success: false,
|
||||||
|
summary: err.to_string(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
transport.send(&completion)
|
||||||
|
}
|
||||||
|
BrowserMessage::Init { .. } => {
|
||||||
|
eprintln!("ignoring duplicate init after handshake");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
BrowserMessage::Response { seq, .. } => {
|
||||||
|
eprintln!("ignoring unsolicited response: seq={seq}");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
72
src/agent/planner.rs
Normal file
72
src/agent/planner.rs
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
use serde_json::{json, Value};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
use crate::pipe::Action;
|
||||||
|
|
||||||
|
/// URL the navigate step opens.
const BAIDU_URL: &str = "https://www.baidu.com";

/// Domain recorded as the expected domain of every planned step.
const BAIDU_DOMAIN: &str = "www.baidu.com";

/// Selector of the element the query text is typed into.
const BAIDU_INPUT_SELECTOR: &str = "#kw";

/// Selector of the element clicked to submit the search.
const BAIDU_SEARCH_BUTTON_SELECTOR: &str = "#su";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct PlannedStep {
|
||||||
|
pub action: Action,
|
||||||
|
pub params: Value,
|
||||||
|
pub expected_domain: String,
|
||||||
|
pub log_message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct TaskPlan {
|
||||||
|
pub summary: String,
|
||||||
|
pub steps: Vec<PlannedStep>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error, Clone, PartialEq, Eq)]
|
||||||
|
pub enum PlannerError {
|
||||||
|
#[error("unsupported instruction: {0}")]
|
||||||
|
UnsupportedInstruction(String),
|
||||||
|
#[error("missing search query in instruction")]
|
||||||
|
MissingQuery,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn plan_instruction(instruction: &str) -> Result<TaskPlan, PlannerError> {
|
||||||
|
let trimmed = instruction.trim();
|
||||||
|
let query = trimmed
|
||||||
|
.strip_prefix("打开百度搜索")
|
||||||
|
.or_else(|| trimmed.strip_prefix("打开百度并搜索"))
|
||||||
|
.ok_or_else(|| PlannerError::UnsupportedInstruction(trimmed.to_string()))?
|
||||||
|
.trim();
|
||||||
|
|
||||||
|
if query.is_empty() {
|
||||||
|
return Err(PlannerError::MissingQuery);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(TaskPlan {
|
||||||
|
summary: format!("已在百度搜索{query}"),
|
||||||
|
steps: vec![
|
||||||
|
PlannedStep {
|
||||||
|
action: Action::Navigate,
|
||||||
|
params: json!({ "url": BAIDU_URL }),
|
||||||
|
expected_domain: BAIDU_DOMAIN.to_string(),
|
||||||
|
log_message: "navigate https://www.baidu.com".to_string(),
|
||||||
|
},
|
||||||
|
PlannedStep {
|
||||||
|
action: Action::Type,
|
||||||
|
params: json!({
|
||||||
|
"selector": BAIDU_INPUT_SELECTOR,
|
||||||
|
"text": query,
|
||||||
|
"clear_first": true
|
||||||
|
}),
|
||||||
|
expected_domain: BAIDU_DOMAIN.to_string(),
|
||||||
|
log_message: format!("type {query} into {BAIDU_INPUT_SELECTOR}"),
|
||||||
|
},
|
||||||
|
PlannedStep {
|
||||||
|
action: Action::Click,
|
||||||
|
params: json!({ "selector": BAIDU_SEARCH_BUTTON_SELECTOR }),
|
||||||
|
expected_domain: BAIDU_DOMAIN.to_string(),
|
||||||
|
log_message: format!("click {BAIDU_SEARCH_BUTTON_SELECTOR}"),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
pub mod agent;
|
||||||
pub mod pipe;
|
pub mod pipe;
|
||||||
pub mod security;
|
pub mod security;
|
||||||
|
|
||||||
@@ -5,6 +6,7 @@ use std::path::PathBuf;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use agent::handle_browser_message;
|
||||||
use pipe::{perform_handshake, BrowserPipeTool, PipeError, StdioTransport, Transport};
|
use pipe::{perform_handshake, BrowserPipeTool, PipeError, StdioTransport, Transport};
|
||||||
use security::MacPolicy;
|
use security::MacPolicy;
|
||||||
|
|
||||||
@@ -19,7 +21,7 @@ pub fn run() -> Result<(), PipeError> {
|
|||||||
let transport = Arc::new(StdioTransport::new(std::io::stdin(), std::io::stdout()));
|
let transport = Arc::new(StdioTransport::new(std::io::stdin(), std::io::stdout()));
|
||||||
let handshake = perform_handshake(transport.as_ref(), Duration::from_secs(5))?;
|
let handshake = perform_handshake(transport.as_ref(), Duration::from_secs(5))?;
|
||||||
let mac_policy = MacPolicy::load_from_path(default_rules_path())?;
|
let mac_policy = MacPolicy::load_from_path(default_rules_path())?;
|
||||||
let _browser_tool = BrowserPipeTool::new(transport.clone(), mac_policy, handshake.session_key)
|
let browser_tool = BrowserPipeTool::new(transport.clone(), mac_policy, handshake.session_key)
|
||||||
.with_response_timeout(Duration::from_secs(30));
|
.with_response_timeout(Duration::from_secs(30));
|
||||||
|
|
||||||
eprintln!("sgclaw ready: agent_id={}", handshake.agent_id);
|
eprintln!("sgclaw ready: agent_id={}", handshake.agent_id);
|
||||||
@@ -27,7 +29,7 @@ pub fn run() -> Result<(), PipeError> {
|
|||||||
loop {
|
loop {
|
||||||
match transport.recv_timeout(Duration::from_secs(3600)) {
|
match transport.recv_timeout(Duration::from_secs(3600)) {
|
||||||
Ok(message) => {
|
Ok(message) => {
|
||||||
eprintln!("ignoring unsolicited browser message: {:?}", message);
|
handle_browser_message(transport.as_ref(), &browser_tool, message)?;
|
||||||
}
|
}
|
||||||
Err(PipeError::Timeout) => continue,
|
Err(PipeError::Timeout) => continue,
|
||||||
Err(PipeError::PipeClosed) => return Ok(()),
|
Err(PipeError::PipeClosed) => return Ok(()),
|
||||||
|
|||||||
41
tests/planner_test.rs
Normal file
41
tests/planner_test.rs
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
use serde_json::json;
|
||||||
|
use sgclaw::agent::planner::{plan_instruction, PlannerError};
|
||||||
|
use sgclaw::pipe::Action;
|
||||||
|
|
||||||
|
/// The plain "打开百度搜索" form yields navigate → type → click with the expected parameters.
#[test]
fn planner_converts_baidu_search_instruction_into_three_steps() {
    let plan = plan_instruction("打开百度搜索天气").unwrap();
    assert_eq!(plan.summary, "已在百度搜索天气");

    // Expected (action, params) pairs, in execution order.
    let expected = [
        (Action::Navigate, json!({ "url": "https://www.baidu.com" })),
        (
            Action::Type,
            json!({ "selector": "#kw", "text": "天气", "clear_first": true }),
        ),
        (Action::Click, json!({ "selector": "#su" })),
    ];

    assert_eq!(plan.steps.len(), 3);
    for (step, (action, params)) in plan.steps.iter().zip(&expected) {
        assert_eq!(&step.action, action);
        assert_eq!(&step.params, params);
    }
}
|
||||||
|
|
||||||
|
/// The "打开百度并搜索" form is accepted and produces the same kind of plan.
#[test]
fn planner_supports_baidu_search_variant_with_conjunction() {
    let plan = plan_instruction("打开百度并搜索电网调度").unwrap();
    let typed_text = plan.steps[1].params["text"].as_str();

    assert_eq!(plan.summary, "已在百度搜索电网调度");
    assert_eq!(typed_text, Some("电网调度"));
}
|
||||||
|
|
||||||
|
/// An instruction with an unknown prefix is rejected and echoed back in the error.
#[test]
fn planner_rejects_unrelated_instruction() {
    let rejected = "打开谷歌搜索天气";

    assert_eq!(
        plan_instruction(rejected),
        Err(PlannerError::UnsupportedInstruction(rejected.to_string()))
    );
}
|
||||||
113
tests/runtime_task_flow_test.rs
Normal file
113
tests/runtime_task_flow_test.rs
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use common::MockTransport;
|
||||||
|
use sgclaw::agent::handle_browser_message;
|
||||||
|
use sgclaw::pipe::{Action, AgentMessage, BrowserMessage, BrowserPipeTool, Timing};
|
||||||
|
use sgclaw::security::MacPolicy;
|
||||||
|
|
||||||
|
fn test_policy() -> MacPolicy {
|
||||||
|
MacPolicy::from_json_str(
|
||||||
|
r#"{
|
||||||
|
"version": "1.0",
|
||||||
|
"domains": { "allowed": ["oa.example.com", "www.baidu.com"] },
|
||||||
|
"pipe_actions": {
|
||||||
|
"allowed": ["click", "type", "navigate", "getText"],
|
||||||
|
"blocked": ["eval", "executeJsInPage"]
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A submitted Baidu search task emits log/command pairs for the three steps and then a
/// successful `TaskComplete`.
#[test]
fn submit_task_sends_three_commands_and_finishes_with_task_complete() {
    // Scripted successful browser reply for the command with the given sequence number.
    let ok_response = |seq, data| BrowserMessage::Response {
        seq,
        success: true,
        data,
        aom_snapshot: vec![],
        timing: Timing {
            queue_ms: 1,
            exec_ms: 20,
        },
    };

    let transport = Arc::new(MockTransport::new(vec![
        ok_response(1, serde_json::json!({ "navigated": true })),
        ok_response(2, serde_json::json!({ "typed": true })),
        ok_response(3, serde_json::json!({ "clicked": true })),
    ]));
    let tool = BrowserPipeTool::new(
        transport.clone(),
        test_policy(),
        vec![1, 2, 3, 4, 5, 6, 7, 8],
    )
    .with_response_timeout(Duration::from_secs(1));

    let task = BrowserMessage::SubmitTask {
        instruction: "打开百度搜索天气".to_string(),
    };
    handle_browser_message(transport.as_ref(), &tool, task).unwrap();

    let sent = transport.sent_messages();
    assert_eq!(sent.len(), 7);

    // True when `msg` is an info-level log entry carrying exactly `expected`.
    let is_info_log = |msg: &AgentMessage, expected: &str| {
        matches!(
            msg,
            AgentMessage::LogEntry { level, message }
                if level == "info" && message == expected
        )
    };

    // Step 1: navigate.
    assert!(is_info_log(&sent[0], "navigate https://www.baidu.com"));
    assert!(matches!(
        &sent[1],
        AgentMessage::Command { seq: 1, action, .. } if *action == Action::Navigate
    ));

    // Step 2: type the query.
    assert!(is_info_log(&sent[2], "type 天气 into #kw"));
    assert!(matches!(
        &sent[3],
        AgentMessage::Command { seq: 2, action, .. } if *action == Action::Type
    ));

    // Step 3: click the search button.
    assert!(is_info_log(&sent[4], "click #su"));
    assert!(matches!(
        &sent[5],
        AgentMessage::Command { seq: 3, action, .. } if *action == Action::Click
    ));

    // Final message reports success with the plan summary.
    assert!(matches!(
        &sent[6],
        AgentMessage::TaskComplete { success: true, summary }
            if summary == "已在百度搜索天气"
    ));
}
|
||||||
Reference in New Issue
Block a user