fix: sanitize provider tool names

This commit is contained in:
zyl
2026-03-30 03:11:43 +08:00
parent 5db25b513e
commit dbb18a094c
5 changed files with 182 additions and 15 deletions

View File

@@ -1290,7 +1290,7 @@ fn compat_runtime_allows_read_skill_under_compact_mode_policy() {
assert!(tool_names.contains(&"browser_action".to_string()));
assert!(tool_names.contains(&"superrpa_browser".to_string()));
assert!(tool_names.contains(&"read_skill".to_string()));
assert!(tool_names.contains(&"zhihu-hotlist.extract_hotlist".to_string()));
assert!(!tool_names.contains(&"zhihu-hotlist_extract_hotlist".to_string()));
}
#[test]
@@ -1368,7 +1368,7 @@ top_n = "How many hotlist rows to extract."
assert!(tool_names.contains(&"browser_action".to_string()));
assert!(tool_names.contains(&"superrpa_browser".to_string()));
assert!(tool_names.contains(&"read_skill".to_string()));
assert!(tool_names.contains(&"workspace-zhihu-skill.extract_hotlist".to_string()));
assert!(tool_names.contains(&"workspace-zhihu-skill_extract_hotlist".to_string()));
}
#[test]
@@ -1383,7 +1383,7 @@ fn compat_runtime_executes_browser_script_skill_via_eval_without_gettext_probing
"id": "call_1",
"type": "function",
"function": {
"name": "workspace-zhihu-skill.extract_hotlist",
"name": "workspace-zhihu-skill_extract_hotlist",
"arguments": serde_json::to_string(&json!({
"expected_domain": "www.zhihu.com",
"top_n": "10"
@@ -1481,7 +1481,7 @@ return {
let tool_names = request_tool_names(&request_bodies[0]);
assert_eq!(summary, "已执行 browser_script skill");
assert!(tool_names.contains(&"workspace-zhihu-skill.extract_hotlist".to_string()));
assert!(tool_names.contains(&"workspace-zhihu-skill_extract_hotlist".to_string()));
assert!(sent.iter().any(|message| {
matches!(message, AgentMessage::LogEntry { level, message }
if level == "info" && message == "call workspace-zhihu-skill.extract_hotlist")
@@ -1544,7 +1544,7 @@ fn zhihu_hotlist_browser_skill_flow_does_not_expose_shell_or_glob_tools() {
assert!(tool_names.contains(&"superrpa_browser".to_string()));
assert!(tool_names.contains(&"browser_action".to_string()));
assert!(tool_names.contains(&"read_skill".to_string()));
assert!(tool_names.contains(&"zhihu-hotlist.extract_hotlist".to_string()));
assert!(tool_names.contains(&"zhihu-hotlist_extract_hotlist".to_string()));
assert!(!tool_names.contains(&"shell".to_string()));
assert!(!tool_names.contains(&"glob_search".to_string()));
}
@@ -1649,7 +1649,7 @@ fn browser_attached_export_flow_exposes_browser_and_office_tools_only() {
assert!(tool_names.contains(&"superrpa_browser".to_string()));
assert!(tool_names.contains(&"browser_action".to_string()));
assert!(tool_names.contains(&"read_skill".to_string()));
assert!(tool_names.contains(&"zhihu-hotlist.extract_hotlist".to_string()));
assert!(tool_names.contains(&"zhihu-hotlist_extract_hotlist".to_string()));
assert!(tool_names.contains(&"openxml_office".to_string()));
assert!(!tool_names.contains(&"shell".to_string()));
assert!(!tool_names.contains(&"glob_search".to_string()));
@@ -1704,7 +1704,7 @@ fn compat_runtime_allows_zhihu_hotlist_screen_export_tool_in_browser_profile() {
assert!(tool_names.contains(&"superrpa_browser".to_string()));
assert!(tool_names.contains(&"browser_action".to_string()));
assert!(tool_names.contains(&"read_skill".to_string()));
assert!(tool_names.contains(&"zhihu-hotlist.extract_hotlist".to_string()));
assert!(tool_names.contains(&"zhihu-hotlist_extract_hotlist".to_string()));
assert!(tool_names.contains(&"screen_html_export".to_string()));
assert!(!tool_names.contains(&"shell".to_string()));
assert!(!tool_names.contains(&"glob_search".to_string()));
@@ -1931,7 +1931,7 @@ fn handle_browser_message_executes_real_zhihu_hotlist_skill_flow() {
"id": "call_1",
"type": "function",
"function": {
"name": "zhihu-hotlist.extract_hotlist",
"name": "zhihu-hotlist_extract_hotlist",
"arguments": serde_json::to_string(&json!({
"expected_domain": "www.zhihu.com",
"top_n": "10"
@@ -2040,7 +2040,7 @@ fn handle_browser_message_chains_hotlist_skill_into_office_export_tool() {
"id": "call_1",
"type": "function",
"function": {
"name": "zhihu-hotlist.extract_hotlist",
"name": "zhihu-hotlist_extract_hotlist",
"arguments": serde_json::to_string(&json!({
"expected_domain": "www.zhihu.com",
"top_n": "10"

View File

@@ -774,7 +774,7 @@ impl Agent {
});
}
let response = match self
let mut response = match self
.provider
.chat(
ChatRequest {
@@ -795,6 +795,8 @@ impl Agent {
};
let (text, calls) = self.tool_dispatcher.parse_response(&response);
let calls = canonicalize_parsed_tool_calls(&self.tools, calls);
response.tool_calls = canonicalize_provider_tool_calls(&self.tools, response.tool_calls);
if calls.is_empty() {
let final_text = if text.is_empty() {
response.text.unwrap_or_default()
@@ -1030,7 +1032,7 @@ impl Agent {
// If streaming produced text, use it as the response and
// check for tool calls via the dispatcher.
let response = if got_stream {
let mut response = if got_stream {
// Build a synthetic ChatResponse from streamed text
crate::providers::ChatResponse {
text: Some(streamed_text),
@@ -1062,6 +1064,8 @@ impl Agent {
};
let (text, calls) = self.tool_dispatcher.parse_response(&response);
let calls = canonicalize_parsed_tool_calls(&self.tools, calls);
response.tool_calls = canonicalize_provider_tool_calls(&self.tools, response.tool_calls);
if calls.is_empty() {
let final_text = if text.is_empty() {
response.text.unwrap_or_default()
@@ -1202,6 +1206,42 @@ fn sanitize_final_text(text: &str) -> String {
result.join("\n\n")
}
/// Maps a tool name reported by the provider back to the name of a registered tool.
///
/// `raw` may be either a tool's registered name or the provider-safe form of it
/// (see `crate::tools::provider_safe_tool_name`). Returns the registered name of
/// the matching tool, or `None` when no registered tool matches either form.
///
/// Exact matches take priority over sanitized matches. Otherwise a tool whose
/// sanitized name happens to equal `raw` (e.g. `a.b` -> `a_b`) could shadow a
/// different tool that is literally registered as `a_b`, depending only on
/// registration order.
fn resolve_registered_tool_name(tools: &[Box<dyn Tool>], raw: &str) -> Option<String> {
    // Pass 1: the provider echoed a registered name verbatim.
    let exact = tools.iter().find(|tool| tool.name() == raw);

    // Pass 2: the provider echoed the sanitized alias that was advertised to it.
    let resolved = exact.or_else(|| {
        tools
            .iter()
            .find(|tool| crate::tools::provider_safe_tool_name(tool.name()) == raw)
    });

    resolved.map(|tool| tool.name().to_string())
}
/// Rewrites the name of every parsed tool call to its registered tool name.
///
/// Calls whose name cannot be resolved against `tools` are passed through
/// untouched; order and all other fields are preserved.
fn canonicalize_parsed_tool_calls(
    tools: &[Box<dyn Tool>],
    mut calls: Vec<ParsedToolCall>,
) -> Vec<ParsedToolCall> {
    for parsed in &mut calls {
        if let Some(registered) = resolve_registered_tool_name(tools, &parsed.name) {
            parsed.name = registered;
        }
    }
    calls
}
/// Rewrites the name of every provider-level tool call to its registered tool name.
///
/// Calls whose name cannot be resolved against `tools` are passed through
/// untouched; order and all other fields are preserved.
fn canonicalize_provider_tool_calls(
    tools: &[Box<dyn Tool>],
    mut calls: Vec<crate::providers::ToolCall>,
) -> Vec<crate::providers::ToolCall> {
    for provider_call in &mut calls {
        if let Some(registered) = resolve_registered_tool_name(tools, &provider_call.name) {
            provider_call.name = registered;
        }
    }
    calls
}
pub async fn run(
config: Config,
message: Option<String>,
@@ -1360,6 +1400,31 @@ mod tests {
}
}
/// Test double for a tool whose registered name contains a `.`.
///
/// The dotted name is the interesting part: it differs from its provider-safe
/// form, so tests can check that calls are mapped back to the registered name.
struct MockDottedTool;
#[async_trait]
impl Tool for MockDottedTool {
/// Registered (un-sanitized) tool name; note the `.` separator.
fn name(&self) -> &str {
"zhihu-hotlist.extract_hotlist"
}
/// Fixed human-readable description.
fn description(&self) -> &str {
"extract zhihu hotlist"
}
/// Minimal JSON schema: an object with no declared properties.
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({"type": "object"})
}
/// Ignores its arguments and always reports success with a fixed output string.
async fn execute(&self, _args: serde_json::Value) -> Result<crate::tools::ToolResult> {
Ok(crate::tools::ToolResult {
success: true,
output: "hotlist-out".into(),
error: None,
})
}
}
struct StreamingDuplicateParagraphProvider;
#[async_trait]
@@ -1507,6 +1572,65 @@ mod tests {
.any(|msg| matches!(msg, ConversationMessage::ToolResults(_))));
}
/// Checks that a tool call arriving under the provider-safe name
/// (`zhihu-hotlist_extract_hotlist`) is surfaced under the registered dotted name
/// (`zhihu-hotlist.extract_hotlist`) both in the emitted `TurnEvent::ToolCall`
/// and in the assistant tool-call entry stored in the agent history.
#[tokio::test]
async fn turn_streamed_restores_original_tool_name_for_provider_safe_calls() {
// Scripted provider: one response carrying a tool call under the sanitized
// name, then a plain-text "done" response.
// NOTE(review): presumably MockProvider replays these in order — confirm.
let provider = Box::new(MockProvider {
responses: Mutex::new(vec![
crate::providers::ChatResponse {
text: Some(String::new()),
tool_calls: vec![crate::providers::ToolCall {
id: "tc1".into(),
name: "zhihu-hotlist_extract_hotlist".into(),
arguments: "{}".into(),
}],
usage: None,
reasoning_content: None,
},
crate::providers::ChatResponse {
text: Some("done".into()),
tool_calls: vec![],
usage: None,
reasoning_content: None,
},
]),
});
// Memory is configured with the "none" backend; other fields keep their defaults.
let memory_cfg = crate::config::MemoryConfig {
backend: "none".into(),
..crate::config::MemoryConfig::default()
};
let mem: Arc<dyn Memory> = Arc::from(
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None)
.expect("memory creation should succeed with valid config"),
);
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
// The only registered tool is the dotted-name mock, so the sanitized call
// name can only be satisfied by mapping it back to that tool.
let mut agent = Agent::builder()
.provider(provider)
.tools(vec![Box::new(MockDottedTool)])
.memory(mem)
.observer(observer)
.tool_dispatcher(Box::new(NativeToolDispatcher))
.workspace_dir(std::path::PathBuf::from("/tmp"))
.build()
.expect("agent builder should succeed with valid config");
// Events are only read after the turn completes, so the channel buffer (8)
// has to hold everything the turn emits.
let (event_tx, mut event_rx) = tokio::sync::mpsc::channel(8);
let response = agent.turn_streamed("导出知乎热榜", event_tx).await.unwrap();
assert_eq!(response, "done");
// The first event received must be the tool call, reported under the registered name.
let tool_event = event_rx.recv().await.expect("tool event should be emitted");
assert!(matches!(
tool_event,
TurnEvent::ToolCall { ref name, .. } if name == "zhihu-hotlist.extract_hotlist"
));
// The stored assistant tool-call message must also carry the registered name.
assert!(agent.history().iter().any(|message| matches!(
message,
ConversationMessage::AssistantToolCalls { tool_calls, .. }
if tool_calls.iter().any(|call| call.name == "zhihu-hotlist.extract_hotlist")
)));
}
#[tokio::test]
async fn turn_streamed_sanitizes_duplicate_final_paragraphs() {
let provider = Box::new(StreamingDuplicateParagraphProvider);

View File

@@ -405,10 +405,11 @@ impl OpenAiCompatibleProvider {
tools
.iter()
.map(|tool| {
let provider_name = crate::tools::provider_safe_tool_name(&tool.name);
serde_json::json!({
"type": "function",
"function": {
"name": tool.name,
"name": provider_name,
"description": tool.description,
"parameters": tool.parameters
}
@@ -1321,10 +1322,11 @@ impl OpenAiCompatibleProvider {
items
.iter()
.map(|tool| {
let provider_name = crate::tools::provider_safe_tool_name(&tool.name);
serde_json::json!({
"type": "function",
"function": {
"name": tool.name,
"name": provider_name,
"description": tool.description,
"parameters": tool.parameters,
}
@@ -1387,7 +1389,7 @@ impl OpenAiCompatibleProvider {
id: Some(tc.id),
kind: Some("function".to_string()),
function: Some(Function {
name: Some(tc.name),
name: Some(crate::tools::provider_safe_tool_name(&tc.name)),
arguments: Some(tc.arguments),
}),
name: None,
@@ -3233,6 +3235,26 @@ mod tests {
assert_eq!(tools[0]["function"]["parameters"]["required"][0], "command");
}
/// A tool spec whose name contains a `.` must be emitted under its
/// provider-safe name when converted to the OpenAI `tools` payload.
#[test]
fn tool_specs_convert_invalid_function_names_to_provider_safe_names() {
    let schema = serde_json::json!({
        "type": "object",
        "properties": {"top_n": {"type": "string"}},
        "required": ["top_n"]
    });
    let dotted_spec = crate::tools::ToolSpec {
        name: "zhihu-hotlist.extract_hotlist".to_string(),
        description: "Extract Zhihu hotlist rows".to_string(),
        parameters: schema,
    };
    let input = vec![dotted_spec];

    let converted = OpenAiCompatibleProvider::tool_specs_to_openai_format(&input);

    // Exactly one entry comes back, carrying the sanitized function name.
    assert_eq!(converted.len(), 1);
    assert_eq!(
        converted[0]["function"]["name"],
        "zhihu-hotlist_extract_hotlist"
    );
}
#[test]
fn request_serializes_with_tools() {
let tools = vec![serde_json::json!({

View File

@@ -204,7 +204,7 @@ pub use text_browser::TextBrowserTool;
pub use tool_search::ToolSearchTool;
pub use traits::Tool;
#[allow(unused_imports)]
pub use traits::{ToolResult, ToolSpec};
pub use traits::{provider_safe_tool_name, ToolResult, ToolSpec};
pub use verifiable_intent::VerifiableIntentTool;
pub use weather_tool::WeatherTool;
pub use web_fetch::WebFetchTool;

View File

@@ -17,6 +17,18 @@ pub struct ToolSpec {
pub parameters: serde_json::Value,
}
/// Returns `name` with every character outside `[A-Za-z0-9_-]` replaced by `_`.
///
/// The replacement is one-for-one per character, so the character count is
/// preserved. No length limit is enforced, an empty input yields an empty
/// string, and distinct inputs can map to the same output (`a.b` and `a_b`
/// both become `a_b`).
pub fn provider_safe_tool_name(name: &str) -> String {
    let mut sanitized = String::with_capacity(name.len());
    for ch in name.chars() {
        let allowed = matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-');
        sanitized.push(if allowed { ch } else { '_' });
    }
    sanitized
}
/// Core tool trait — implement for any capability
#[async_trait]
pub trait Tool: Send + Sync {
@@ -118,4 +130,13 @@ mod tests {
assert!(!parsed.success);
assert_eq!(parsed.error.as_deref(), Some("boom"));
}
/// Pins the sanitizing contract of `provider_safe_tool_name`: allowed characters
/// (ASCII letters, digits, `_`, `-`) pass through unchanged and every other
/// character becomes exactly one `_`.
#[test]
fn provider_safe_tool_name_replaces_invalid_function_characters() {
    // Dotted skill-style name: only the `.` is rewritten.
    assert_eq!(
        provider_safe_tool_name("zhihu-hotlist.extract_hotlist"),
        "zhihu-hotlist_extract_hotlist"
    );
    // Already-safe names are returned unchanged.
    assert_eq!(provider_safe_tool_name("shell"), "shell");
    assert_eq!(provider_safe_tool_name("A-z_09"), "A-z_09");
    // Whitespace and other ASCII punctuation are each replaced.
    assert_eq!(provider_safe_tool_name("a b/c:d"), "a_b_c_d");
    // Non-ASCII characters are replaced one-for-one, not per byte.
    assert_eq!(provider_safe_tool_name("知乎.x"), "___x");
    // Empty input stays empty.
    assert_eq!(provider_safe_tool_name(""), "");
}
}