fix: sanitize provider tool names
This commit is contained in:
128
third_party/zeroclaw/src/agent/agent.rs
vendored
128
third_party/zeroclaw/src/agent/agent.rs
vendored
@@ -774,7 +774,7 @@ impl Agent {
|
||||
});
|
||||
}
|
||||
|
||||
let response = match self
|
||||
let mut response = match self
|
||||
.provider
|
||||
.chat(
|
||||
ChatRequest {
|
||||
@@ -795,6 +795,8 @@ impl Agent {
|
||||
};
|
||||
|
||||
let (text, calls) = self.tool_dispatcher.parse_response(&response);
|
||||
let calls = canonicalize_parsed_tool_calls(&self.tools, calls);
|
||||
response.tool_calls = canonicalize_provider_tool_calls(&self.tools, response.tool_calls);
|
||||
if calls.is_empty() {
|
||||
let final_text = if text.is_empty() {
|
||||
response.text.unwrap_or_default()
|
||||
@@ -1030,7 +1032,7 @@ impl Agent {
|
||||
|
||||
// If streaming produced text, use it as the response and
|
||||
// check for tool calls via the dispatcher.
|
||||
let response = if got_stream {
|
||||
let mut response = if got_stream {
|
||||
// Build a synthetic ChatResponse from streamed text
|
||||
crate::providers::ChatResponse {
|
||||
text: Some(streamed_text),
|
||||
@@ -1062,6 +1064,8 @@ impl Agent {
|
||||
};
|
||||
|
||||
let (text, calls) = self.tool_dispatcher.parse_response(&response);
|
||||
let calls = canonicalize_parsed_tool_calls(&self.tools, calls);
|
||||
response.tool_calls = canonicalize_provider_tool_calls(&self.tools, response.tool_calls);
|
||||
if calls.is_empty() {
|
||||
let final_text = if text.is_empty() {
|
||||
response.text.unwrap_or_default()
|
||||
@@ -1202,6 +1206,42 @@ fn sanitize_final_text(text: &str) -> String {
|
||||
result.join("\n\n")
|
||||
}
|
||||
|
||||
fn resolve_registered_tool_name(tools: &[Box<dyn Tool>], raw: &str) -> Option<String> {
|
||||
tools.iter()
|
||||
.find(|tool| {
|
||||
tool.name() == raw || crate::tools::provider_safe_tool_name(tool.name()) == raw
|
||||
})
|
||||
.map(|tool| tool.name().to_string())
|
||||
}
|
||||
|
||||
fn canonicalize_parsed_tool_calls(
|
||||
tools: &[Box<dyn Tool>],
|
||||
calls: Vec<ParsedToolCall>,
|
||||
) -> Vec<ParsedToolCall> {
|
||||
calls.into_iter()
|
||||
.map(|mut call| {
|
||||
if let Some(canonical_name) = resolve_registered_tool_name(tools, &call.name) {
|
||||
call.name = canonical_name;
|
||||
}
|
||||
call
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn canonicalize_provider_tool_calls(
|
||||
tools: &[Box<dyn Tool>],
|
||||
calls: Vec<crate::providers::ToolCall>,
|
||||
) -> Vec<crate::providers::ToolCall> {
|
||||
calls.into_iter()
|
||||
.map(|mut call| {
|
||||
if let Some(canonical_name) = resolve_registered_tool_name(tools, &call.name) {
|
||||
call.name = canonical_name;
|
||||
}
|
||||
call
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
config: Config,
|
||||
message: Option<String>,
|
||||
@@ -1360,6 +1400,31 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
struct MockDottedTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for MockDottedTool {
|
||||
fn name(&self) -> &str {
|
||||
"zhihu-hotlist.extract_hotlist"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"extract zhihu hotlist"
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
serde_json::json!({"type": "object"})
|
||||
}
|
||||
|
||||
async fn execute(&self, _args: serde_json::Value) -> Result<crate::tools::ToolResult> {
|
||||
Ok(crate::tools::ToolResult {
|
||||
success: true,
|
||||
output: "hotlist-out".into(),
|
||||
error: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
struct StreamingDuplicateParagraphProvider;
|
||||
|
||||
#[async_trait]
|
||||
@@ -1507,6 +1572,65 @@ mod tests {
|
||||
.any(|msg| matches!(msg, ConversationMessage::ToolResults(_))));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_streamed_restores_original_tool_name_for_provider_safe_calls() {
|
||||
let provider = Box::new(MockProvider {
|
||||
responses: Mutex::new(vec![
|
||||
crate::providers::ChatResponse {
|
||||
text: Some(String::new()),
|
||||
tool_calls: vec![crate::providers::ToolCall {
|
||||
id: "tc1".into(),
|
||||
name: "zhihu-hotlist_extract_hotlist".into(),
|
||||
arguments: "{}".into(),
|
||||
}],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
crate::providers::ChatResponse {
|
||||
text: Some("done".into()),
|
||||
tool_calls: vec![],
|
||||
usage: None,
|
||||
reasoning_content: None,
|
||||
},
|
||||
]),
|
||||
});
|
||||
|
||||
let memory_cfg = crate::config::MemoryConfig {
|
||||
backend: "none".into(),
|
||||
..crate::config::MemoryConfig::default()
|
||||
};
|
||||
let mem: Arc<dyn Memory> = Arc::from(
|
||||
crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None)
|
||||
.expect("memory creation should succeed with valid config"),
|
||||
);
|
||||
|
||||
let observer: Arc<dyn Observer> = Arc::from(crate::observability::NoopObserver {});
|
||||
let mut agent = Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(vec![Box::new(MockDottedTool)])
|
||||
.memory(mem)
|
||||
.observer(observer)
|
||||
.tool_dispatcher(Box::new(NativeToolDispatcher))
|
||||
.workspace_dir(std::path::PathBuf::from("/tmp"))
|
||||
.build()
|
||||
.expect("agent builder should succeed with valid config");
|
||||
|
||||
let (event_tx, mut event_rx) = tokio::sync::mpsc::channel(8);
|
||||
let response = agent.turn_streamed("导出知乎热榜", event_tx).await.unwrap();
|
||||
|
||||
assert_eq!(response, "done");
|
||||
let tool_event = event_rx.recv().await.expect("tool event should be emitted");
|
||||
assert!(matches!(
|
||||
tool_event,
|
||||
TurnEvent::ToolCall { ref name, .. } if name == "zhihu-hotlist.extract_hotlist"
|
||||
));
|
||||
assert!(agent.history().iter().any(|message| matches!(
|
||||
message,
|
||||
ConversationMessage::AssistantToolCalls { tool_calls, .. }
|
||||
if tool_calls.iter().any(|call| call.name == "zhihu-hotlist.extract_hotlist")
|
||||
)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn turn_streamed_sanitizes_duplicate_final_paragraphs() {
|
||||
let provider = Box::new(StreamingDuplicateParagraphProvider);
|
||||
|
||||
Reference in New Issue
Block a user