fix: sanitize provider tool names

This commit is contained in:
zyl
2026-03-30 03:11:43 +08:00
parent 5db25b513e
commit dbb18a094c
5 changed files with 182 additions and 15 deletions

View File

@@ -774,7 +774,7 @@ impl Agent {
});
}
let response = match self
let mut response = match self
.provider
.chat(
ChatRequest {
@@ -795,6 +795,8 @@ impl Agent {
};
let (text, calls) = self.tool_dispatcher.parse_response(&response);
let calls = canonicalize_parsed_tool_calls(&self.tools, calls);
response.tool_calls = canonicalize_provider_tool_calls(&self.tools, response.tool_calls);
if calls.is_empty() {
let final_text = if text.is_empty() {
response.text.unwrap_or_default()
@@ -1030,7 +1032,7 @@ impl Agent {
// If streaming produced text, use it as the response and
// check for tool calls via the dispatcher.
let response = if got_stream {
let mut response = if got_stream {
// Build a synthetic ChatResponse from streamed text
crate::providers::ChatResponse {
text: Some(streamed_text),
@@ -1062,6 +1064,8 @@ impl Agent {
};
let (text, calls) = self.tool_dispatcher.parse_response(&response);
let calls = canonicalize_parsed_tool_calls(&self.tools, calls);
response.tool_calls = canonicalize_provider_tool_calls(&self.tools, response.tool_calls);
if calls.is_empty() {
let final_text = if text.is_empty() {
response.text.unwrap_or_default()
@@ -1202,6 +1206,42 @@ fn sanitize_final_text(text: &str) -> String {
result.join("\n\n")
}
fn resolve_registered_tool_name(tools: &[Box<dyn Tool>], raw: &str) -> Option<String> {
tools.iter()
.find(|tool| {
tool.name() == raw || crate::tools::provider_safe_tool_name(tool.name()) == raw
})
.map(|tool| tool.name().to_string())
}
fn canonicalize_parsed_tool_calls(
tools: &[Box<dyn Tool>],
calls: Vec<ParsedToolCall>,
) -> Vec<ParsedToolCall> {
calls.into_iter()
.map(|mut call| {
if let Some(canonical_name) = resolve_registered_tool_name(tools, &call.name) {
call.name = canonical_name;
}
call
})
.collect()
}
fn canonicalize_provider_tool_calls(
tools: &[Box<dyn Tool>],
calls: Vec<crate::providers::ToolCall>,
) -> Vec<crate::providers::ToolCall> {
calls.into_iter()
.map(|mut call| {
if let Some(canonical_name) = resolve_registered_tool_name(tools, &call.name) {
call.name = canonical_name;
}
call
})
.collect()
}
pub async fn run(
config: Config,
message: Option<String>,
@@ -1360,6 +1400,31 @@ mod tests {
}
}
struct MockDottedTool;
#[async_trait]
impl Tool for MockDottedTool {
fn name(&self) -> &str {
"zhihu-hotlist.extract_hotlist"
}
fn description(&self) -> &str {
"extract zhihu hotlist"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({"type": "object"})
}
async fn execute(&self, _args: serde_json::Value) -> Result<crate::tools::ToolResult> {
Ok(crate::tools::ToolResult {
success: true,
output: "hotlist-out".into(),
error: None,
})
}
}
struct StreamingDuplicateParagraphProvider;
#[async_trait]
@@ -1507,6 +1572,65 @@ mod tests {
.any(|msg| matches!(msg, ConversationMessage::ToolResults(_))));
}
#[tokio::test]
async fn turn_streamed_restores_original_tool_name_for_provider_safe_calls() {
let provider = Box::new(MockProvider {
responses: Mutex::new(vec![
crate::providers::ChatResponse {
text: Some(String::new()),
tool_calls: vec![crate::providers::ToolCall {
id: "tc1".into(),
name: "zhihu-hotlist_extract_hotlist".into(),
arguments: "{}".into(),
}],
usage: None,
reasoning_content: None,
},
crate::providers::ChatResponse {
text: Some("done".into()),
tool_calls: vec![],
usage: None,
reasoning_content: None,
},
]),
});
let memory_cfg = crate::config::MemoryConfig {
backend: "none".into(),
..crate::config::MemoryConfig::default()
};
let mem: Arc<dyn Memory> = Arc::from(
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None)
.expect("memory creation should succeed with valid config"),
);
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
let mut agent = Agent::builder()
.provider(provider)
.tools(vec![Box::new(MockDottedTool)])
.memory(mem)
.observer(observer)
.tool_dispatcher(Box::new(NativeToolDispatcher))
.workspace_dir(std::path::PathBuf::from("/tmp"))
.build()
.expect("agent builder should succeed with valid config");
let (event_tx, mut event_rx) = tokio::sync::mpsc::channel(8);
let response = agent.turn_streamed("导出知乎热榜", event_tx).await.unwrap();
assert_eq!(response, "done");
let tool_event = event_rx.recv().await.expect("tool event should be emitted");
assert!(matches!(
tool_event,
TurnEvent::ToolCall { ref name, .. } if name == "zhihu-hotlist.extract_hotlist"
));
assert!(agent.history().iter().any(|message| matches!(
message,
ConversationMessage::AssistantToolCalls { tool_calls, .. }
if tool_calls.iter().any(|call| call.name == "zhihu-hotlist.extract_hotlist")
)));
}
#[tokio::test]
async fn turn_streamed_sanitizes_duplicate_final_paragraphs() {
let provider = Box::new(StreamingDuplicateParagraphProvider);