diff --git a/src/agent/mod.rs b/src/agent/mod.rs index abfb265..0ad5e0d 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,5 +1,7 @@ pub mod planner; +pub mod runtime; +use crate::llm::DeepSeekProvider; use crate::pipe::{AgentMessage, BrowserMessage, BrowserPipeTool, PipeError, Transport}; pub fn execute_task( @@ -39,14 +41,31 @@ pub fn handle_browser_message( ) -> Result<(), PipeError> { match message { BrowserMessage::SubmitTask { instruction } => { - let completion = match execute_task(transport, browser_tool, &instruction) { - Ok(summary) => AgentMessage::TaskComplete { - success: true, - summary, + let completion = match DeepSeekProvider::from_env() { + Ok(provider) => match runtime::execute_task_with_provider( + transport, + browser_tool, + &provider, + &instruction, + ) { + Ok(summary) => AgentMessage::TaskComplete { + success: true, + summary, + }, + Err(err) => AgentMessage::TaskComplete { + success: false, + summary: err.to_string(), + }, }, - Err(err) => AgentMessage::TaskComplete { - success: false, - summary: err.to_string(), + Err(_) => match execute_task(transport, browser_tool, &instruction) { + Ok(summary) => AgentMessage::TaskComplete { + success: true, + summary, + }, + Err(err) => AgentMessage::TaskComplete { + success: false, + summary: err.to_string(), + }, }, }; transport.send(&completion) diff --git a/src/agent/runtime.rs b/src/agent/runtime.rs new file mode 100644 index 0000000..0a5f8fb --- /dev/null +++ b/src/agent/runtime.rs @@ -0,0 +1,152 @@ +use serde_json::{json, Map, Value}; + +use crate::llm::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall}; +use crate::pipe::{Action, AgentMessage, BrowserPipeTool, PipeError, Transport}; + +const BROWSER_ACTION_TOOL_NAME: &str = "browser_action"; + +#[derive(Debug, Clone, PartialEq)] +struct BrowserActionCall { + action: Action, + expected_domain: String, + params: Value, +} + +pub fn execute_task_with_provider( + transport: &T, + browser_tool: &BrowserPipeTool, + provider: &P, + instruction: &str, +) -> Result { + let messages = vec![ + ChatMessage { + role: "system".to_string(), + content: "You are sgClaw. Use browser_action to complete the browser task." + .to_string(), + }, + ChatMessage { + role: "user".to_string(), + content: instruction.to_string(), + }, + ]; + let tools = vec![browser_action_tool_definition()]; + let calls = provider + .chat(&messages, &tools) + .map_err(map_llm_error_to_pipe_error)?; + + for call in calls { + let browser_call = parse_browser_action_call(call) + .map_err(|err| PipeError::Protocol(err.to_string()))?; + + transport.send(&AgentMessage::LogEntry { + level: "info".to_string(), + message: format!( + "{} {}", + browser_call.action.as_str(), + browser_call.expected_domain + ), + })?; + + let result = browser_tool.invoke( + browser_call.action, + browser_call.params, + &browser_call.expected_domain, + )?; + if !result.success { + return Err(PipeError::Protocol(format!( + "browser action failed: {}", + result.data + ))); + } + } + + Ok(format!("已通过 Agent 执行任务: {instruction}")) +} + +pub fn browser_action_tool_definition() -> ToolDefinition { + ToolDefinition { + name: BROWSER_ACTION_TOOL_NAME.to_string(), + description: "Execute browser actions in SuperRPA".to_string(), + parameters: json!({ + "type": "object", + "required": ["action", "expected_domain"], + "properties": { + "action": { "type": "string", "enum": ["click", "type", "navigate", "getText"] }, + "expected_domain": { "type": "string" }, + "selector": { "type": "string" }, + "text": { "type": "string" }, + "url": { "type": "string" }, + "clear_first": { "type": "boolean" } + } + }), + } +} + +fn parse_browser_action_call(call: ToolFunctionCall) -> Result { + if call.name != BROWSER_ACTION_TOOL_NAME { + return Err(RuntimeError::UnsupportedTool(call.name)); + } + + let mut args = match call.arguments { + Value::Object(args) => args, + other => { + return Err(RuntimeError::InvalidArguments(format!( + "expected object arguments, got {other}" + ))) + } + }; + + let action_name = take_required_string(&mut args, "action")?; + let expected_domain = take_required_string(&mut args, "expected_domain")?; + let action = parse_action(&action_name)?; + let params = Value::Object(action_params_from_args(args)); + + Ok(BrowserActionCall { + action, + expected_domain, + params, + }) +} + +fn map_llm_error_to_pipe_error(err: LlmError) -> PipeError { + PipeError::Protocol(err.to_string()) +} + +fn parse_action(action_name: &str) -> Result { + match action_name { + "click" => Ok(Action::Click), + "type" => Ok(Action::Type), + "navigate" => Ok(Action::Navigate), + "getText" => Ok(Action::GetText), + other => Err(RuntimeError::UnsupportedAction(other.to_string())), + } +} + +fn take_required_string( + args: &mut Map, + key: &'static str, +) -> Result { + match args.remove(key) { + Some(Value::String(value)) if !value.trim().is_empty() => Ok(value), + Some(other) => Err(RuntimeError::InvalidArguments(format!( + "{key} must be a non-empty string, got {other}" + ))), + None => Err(RuntimeError::MissingField(key)), + } +} + +fn action_params_from_args(args: Map) -> Map { + args +} + +#[derive(Debug, thiserror::Error)] +enum RuntimeError { + #[error("unsupported tool: {0}")] + UnsupportedTool(String), + #[error("unsupported action: {0}")] + UnsupportedAction(String), + #[error("missing required field: {0}")] + MissingField(&'static str), + #[error("invalid tool arguments: {0}")] + InvalidArguments(String), +} diff --git a/tests/agent_runtime_test.rs b/tests/agent_runtime_test.rs new file mode 100644 index 0000000..5058441 --- /dev/null +++ b/tests/agent_runtime_test.rs @@ -0,0 +1,134 @@ +mod common; + +use std::sync::Arc; +use std::time::Duration; + +use common::MockTransport; +use sgclaw::agent::runtime::{browser_action_tool_definition, execute_task_with_provider}; +use sgclaw::llm::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall}; +use sgclaw::pipe::{Action, AgentMessage, BrowserMessage, BrowserPipeTool, Timing}; +use sgclaw::security::MacPolicy; + +struct FakeProvider { + calls: Vec, +} + +impl LlmProvider for FakeProvider { + fn chat( + &self, + _messages: &[ChatMessage], + _tools: &[ToolDefinition], + ) -> Result, LlmError> { + Ok(self.calls.clone()) + } +} + +fn test_policy() -> MacPolicy { + MacPolicy::from_json_str( + r#"{ + "version": "1.0", + "domains": { "allowed": ["www.baidu.com"] }, + "pipe_actions": { + "allowed": ["click", "type", "navigate", "getText"], + "blocked": [] + } + }"#, + ) + .unwrap() +} + +#[test] +fn browser_action_tool_definition_uses_expected_name() { + let tool = browser_action_tool_definition(); + + assert_eq!(tool.name, "browser_action"); + assert_eq!(tool.parameters["required"][0], "action"); + assert_eq!(tool.parameters["required"][1], "expected_domain"); +} + +#[test] +fn runtime_executes_provider_tool_calls_and_returns_summary() { + let transport = Arc::new(MockTransport::new(vec![ + BrowserMessage::Response { + seq: 1, + success: true, + data: serde_json::json!({ "navigated": true }), + aom_snapshot: vec![], + timing: Timing { + queue_ms: 1, + exec_ms: 10, + }, + }, + BrowserMessage::Response { + seq: 2, + success: true, + data: serde_json::json!({ "typed": true }), + aom_snapshot: vec![], + timing: Timing { + queue_ms: 1, + exec_ms: 10, + }, + }, + ])); + let browser_tool = BrowserPipeTool::new( + transport.clone(), + test_policy(), + vec![1, 2, 3, 4, 5, 6, 7, 8], + ) + .with_response_timeout(Duration::from_secs(1)); + let provider = FakeProvider { + calls: vec![ + ToolFunctionCall { + id: "call-1".to_string(), + name: "browser_action".to_string(), + arguments: serde_json::json!({ + "action": "navigate", + "expected_domain": "www.baidu.com", + "url": "https://www.baidu.com" + }), + }, + ToolFunctionCall { + id: "call-2".to_string(), + name: "browser_action".to_string(), + arguments: serde_json::json!({ + "action": "type", + "expected_domain": "www.baidu.com", + "selector": "#kw", + "text": "天气", + "clear_first": true + }), + }, + ], + }; + + let summary = execute_task_with_provider( + transport.as_ref(), + &browser_tool, + &provider, + "打开百度搜索天气", + ) + .unwrap(); + let sent = transport.sent_messages(); + + assert_eq!(summary, "已通过 Agent 执行任务: 打开百度搜索天气"); + assert!(matches!( + &sent[0], + AgentMessage::LogEntry { level, message } + if level == "info" && message == "navigate www.baidu.com" + )); + assert!(matches!( + &sent[1], + AgentMessage::Command { seq, action, .. } + if *seq == 1 && action == &Action::Navigate + )); + assert!(matches!( + &sent[2], + AgentMessage::LogEntry { level, message } + if level == "info" && message == "type www.baidu.com" + )); + assert!(matches!( + &sent[3], + AgentMessage::Command { seq, action, .. } + if *seq == 2 && action == &Action::Type + )); +}