feat: add provider-backed agent runtime
This commit is contained in:
@@ -1,5 +1,7 @@
|
|||||||
pub mod planner;
|
pub mod planner;
|
||||||
|
pub mod runtime;
|
||||||
|
|
||||||
|
use crate::llm::DeepSeekProvider;
|
||||||
use crate::pipe::{AgentMessage, BrowserMessage, BrowserPipeTool, PipeError, Transport};
|
use crate::pipe::{AgentMessage, BrowserMessage, BrowserPipeTool, PipeError, Transport};
|
||||||
|
|
||||||
pub fn execute_task<T: Transport>(
|
pub fn execute_task<T: Transport>(
|
||||||
@@ -39,7 +41,13 @@ pub fn handle_browser_message<T: Transport>(
|
|||||||
) -> Result<(), PipeError> {
|
) -> Result<(), PipeError> {
|
||||||
match message {
|
match message {
|
||||||
BrowserMessage::SubmitTask { instruction } => {
|
BrowserMessage::SubmitTask { instruction } => {
|
||||||
let completion = match execute_task(transport, browser_tool, &instruction) {
|
let completion = match DeepSeekProvider::from_env() {
|
||||||
|
Ok(provider) => match runtime::execute_task_with_provider(
|
||||||
|
transport,
|
||||||
|
browser_tool,
|
||||||
|
&provider,
|
||||||
|
&instruction,
|
||||||
|
) {
|
||||||
Ok(summary) => AgentMessage::TaskComplete {
|
Ok(summary) => AgentMessage::TaskComplete {
|
||||||
success: true,
|
success: true,
|
||||||
summary,
|
summary,
|
||||||
@@ -48,6 +56,17 @@ pub fn handle_browser_message<T: Transport>(
|
|||||||
success: false,
|
success: false,
|
||||||
summary: err.to_string(),
|
summary: err.to_string(),
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
Err(_) => match execute_task(transport, browser_tool, &instruction) {
|
||||||
|
Ok(summary) => AgentMessage::TaskComplete {
|
||||||
|
success: true,
|
||||||
|
summary,
|
||||||
|
},
|
||||||
|
Err(err) => AgentMessage::TaskComplete {
|
||||||
|
success: false,
|
||||||
|
summary: err.to_string(),
|
||||||
|
},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
transport.send(&completion)
|
transport.send(&completion)
|
||||||
}
|
}
|
||||||
|
|||||||
152
src/agent/runtime.rs
Normal file
152
src/agent/runtime.rs
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
use serde_json::{json, Map, Value};
|
||||||
|
|
||||||
|
use crate::llm::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
|
||||||
|
use crate::pipe::{Action, AgentMessage, BrowserPipeTool, PipeError, Transport};
|
||||||
|
|
||||||
|
/// Name of the only tool this runtime offers to the provider and accepts in returned tool calls.
const BROWSER_ACTION_TOOL_NAME: &str = "browser_action";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
struct BrowserActionCall {
|
||||||
|
action: Action,
|
||||||
|
expected_domain: String,
|
||||||
|
params: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn execute_task_with_provider<P: LlmProvider, T: Transport>(
|
||||||
|
transport: &T,
|
||||||
|
browser_tool: &BrowserPipeTool<T>,
|
||||||
|
provider: &P,
|
||||||
|
instruction: &str,
|
||||||
|
) -> Result<String, PipeError> {
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage {
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: "You are sgClaw. Use browser_action to complete the browser task."
|
||||||
|
.to_string(),
|
||||||
|
},
|
||||||
|
ChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: instruction.to_string(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
let tools = vec![browser_action_tool_definition()];
|
||||||
|
let calls = provider
|
||||||
|
.chat(&messages, &tools)
|
||||||
|
.map_err(map_llm_error_to_pipe_error)?;
|
||||||
|
|
||||||
|
for call in calls {
|
||||||
|
let browser_call = parse_browser_action_call(call)
|
||||||
|
.map_err(|err| PipeError::Protocol(err.to_string()))?;
|
||||||
|
|
||||||
|
transport.send(&AgentMessage::LogEntry {
|
||||||
|
level: "info".to_string(),
|
||||||
|
message: format!(
|
||||||
|
"{} {}",
|
||||||
|
browser_call.action.as_str(),
|
||||||
|
browser_call.expected_domain
|
||||||
|
),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let result = browser_tool.invoke(
|
||||||
|
browser_call.action,
|
||||||
|
browser_call.params,
|
||||||
|
&browser_call.expected_domain,
|
||||||
|
)?;
|
||||||
|
if !result.success {
|
||||||
|
return Err(PipeError::Protocol(format!(
|
||||||
|
"browser action failed: {}",
|
||||||
|
result.data
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(format!("已通过 Agent 执行任务: {instruction}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn browser_action_tool_definition() -> ToolDefinition {
|
||||||
|
ToolDefinition {
|
||||||
|
name: BROWSER_ACTION_TOOL_NAME.to_string(),
|
||||||
|
description: "Execute browser actions in SuperRPA".to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"required": ["action", "expected_domain"],
|
||||||
|
"properties": {
|
||||||
|
"action": { "type": "string", "enum": ["click", "type", "navigate", "getText"] },
|
||||||
|
"expected_domain": { "type": "string" },
|
||||||
|
"selector": { "type": "string" },
|
||||||
|
"text": { "type": "string" },
|
||||||
|
"url": { "type": "string" },
|
||||||
|
"clear_first": { "type": "boolean" }
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_browser_action_call(call: ToolFunctionCall) -> Result<BrowserActionCall, RuntimeError> {
|
||||||
|
if call.name != BROWSER_ACTION_TOOL_NAME {
|
||||||
|
return Err(RuntimeError::UnsupportedTool(call.name));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut args = match call.arguments {
|
||||||
|
Value::Object(args) => args,
|
||||||
|
other => {
|
||||||
|
return Err(RuntimeError::InvalidArguments(format!(
|
||||||
|
"expected object arguments, got {other}"
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let action_name = take_required_string(&mut args, "action")?;
|
||||||
|
let expected_domain = take_required_string(&mut args, "expected_domain")?;
|
||||||
|
let action = parse_action(&action_name)?;
|
||||||
|
let params = Value::Object(action_params_from_args(args));
|
||||||
|
|
||||||
|
Ok(BrowserActionCall {
|
||||||
|
action,
|
||||||
|
expected_domain,
|
||||||
|
params,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn map_llm_error_to_pipe_error(err: LlmError) -> PipeError {
|
||||||
|
PipeError::Protocol(err.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_action(action_name: &str) -> Result<Action, RuntimeError> {
|
||||||
|
match action_name {
|
||||||
|
"click" => Ok(Action::Click),
|
||||||
|
"type" => Ok(Action::Type),
|
||||||
|
"navigate" => Ok(Action::Navigate),
|
||||||
|
"getText" => Ok(Action::GetText),
|
||||||
|
other => Err(RuntimeError::UnsupportedAction(other.to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn take_required_string(
|
||||||
|
args: &mut Map<String, Value>,
|
||||||
|
key: &'static str,
|
||||||
|
) -> Result<String, RuntimeError> {
|
||||||
|
match args.remove(key) {
|
||||||
|
Some(Value::String(value)) if !value.trim().is_empty() => Ok(value),
|
||||||
|
Some(other) => Err(RuntimeError::InvalidArguments(format!(
|
||||||
|
"{key} must be a non-empty string, got {other}"
|
||||||
|
))),
|
||||||
|
None => Err(RuntimeError::MissingField(key)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Turns the leftover tool-call arguments into the action's parameter map.
///
/// Currently a pass-through: the map is returned exactly as received.
fn action_params_from_args(args: Map<String, Value>) -> Map<String, Value> {
    args
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
enum RuntimeError {
|
||||||
|
#[error("unsupported tool: {0}")]
|
||||||
|
UnsupportedTool(String),
|
||||||
|
#[error("unsupported action: {0}")]
|
||||||
|
UnsupportedAction(String),
|
||||||
|
#[error("missing required field: {0}")]
|
||||||
|
MissingField(&'static str),
|
||||||
|
#[error("invalid tool arguments: {0}")]
|
||||||
|
InvalidArguments(String),
|
||||||
|
}
|
||||||
134
tests/agent_runtime_test.rs
Normal file
134
tests/agent_runtime_test.rs
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use common::MockTransport;
|
||||||
|
use sgclaw::agent::runtime::{browser_action_tool_definition, execute_task_with_provider};
|
||||||
|
use sgclaw::llm::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
|
||||||
|
use sgclaw::pipe::{Action, AgentMessage, BrowserMessage, BrowserPipeTool, Timing};
|
||||||
|
use sgclaw::security::MacPolicy;
|
||||||
|
|
||||||
|
struct FakeProvider {
|
||||||
|
calls: Vec<ToolFunctionCall>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlmProvider for FakeProvider {
|
||||||
|
fn chat(
|
||||||
|
&self,
|
||||||
|
_messages: &[ChatMessage],
|
||||||
|
_tools: &[ToolDefinition],
|
||||||
|
) -> Result<Vec<ToolFunctionCall>, LlmError> {
|
||||||
|
Ok(self.calls.clone())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn test_policy() -> MacPolicy {
|
||||||
|
MacPolicy::from_json_str(
|
||||||
|
r#"{
|
||||||
|
"version": "1.0",
|
||||||
|
"domains": { "allowed": ["www.baidu.com"] },
|
||||||
|
"pipe_actions": {
|
||||||
|
"allowed": ["click", "type", "navigate", "getText"],
|
||||||
|
"blocked": []
|
||||||
|
}
|
||||||
|
}"#,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The advertised tool is named `browser_action` and requires `action` then `expected_domain`.
#[test]
fn browser_action_tool_definition_uses_expected_name() {
    let definition = browser_action_tool_definition();
    let required = &definition.parameters["required"];

    assert_eq!(definition.name, "browser_action");
    assert_eq!(required[0], "action");
    assert_eq!(required[1], "expected_domain");
}
|
||||||
|
|
||||||
|
/// Feeds two scripted `browser_action` calls (navigate, then type) through the runtime and
/// checks the returned summary plus the alternating log-entry / command messages that were
/// written to the transport.
#[test]
fn runtime_executes_provider_tool_calls_and_returns_summary() {
    // Fixture builders: a successful browser response and a `browser_action` tool call.
    let ok_response = |seq, data| BrowserMessage::Response {
        seq,
        success: true,
        data,
        aom_snapshot: vec![],
        timing: Timing {
            queue_ms: 1,
            exec_ms: 10,
        },
    };
    let browser_call = |id: &str, arguments| ToolFunctionCall {
        id: id.to_string(),
        name: "browser_action".to_string(),
        arguments,
    };

    let transport = Arc::new(MockTransport::new(vec![
        ok_response(1, serde_json::json!({ "navigated": true })),
        ok_response(2, serde_json::json!({ "typed": true })),
    ]));
    let browser_tool = BrowserPipeTool::new(
        transport.clone(),
        test_policy(),
        vec![1, 2, 3, 4, 5, 6, 7, 8],
    )
    .with_response_timeout(Duration::from_secs(1));
    let provider = FakeProvider {
        calls: vec![
            browser_call(
                "call-1",
                serde_json::json!({
                    "action": "navigate",
                    "expected_domain": "www.baidu.com",
                    "url": "https://www.baidu.com"
                }),
            ),
            browser_call(
                "call-2",
                serde_json::json!({
                    "action": "type",
                    "expected_domain": "www.baidu.com",
                    "selector": "#kw",
                    "text": "天气",
                    "clear_first": true
                }),
            ),
        ],
    };

    let summary = execute_task_with_provider(
        transport.as_ref(),
        &browser_tool,
        &provider,
        "打开百度搜索天气",
    )
    .unwrap();
    let sent = transport.sent_messages();

    assert_eq!(summary, "已通过 Agent 执行任务: 打开百度搜索天气");

    // First call: log entry, then the navigate command.
    assert!(matches!(
        &sent[0],
        AgentMessage::LogEntry { level, message }
            if level == "info" && message == "navigate www.baidu.com"
    ));
    assert!(matches!(
        &sent[1],
        AgentMessage::Command { seq, action, .. }
            if *seq == 1 && action == &Action::Navigate
    ));

    // Second call: log entry, then the type command.
    assert!(matches!(
        &sent[2],
        AgentMessage::LogEntry { level, message }
            if level == "info" && message == "type www.baidu.com"
    ));
    assert!(matches!(
        &sent[3],
        AgentMessage::Command { seq, action, .. }
            if *seq == 2 && action == &Action::Type
    ));
}
|
||||||
Reference in New Issue
Block a user