feat: refactor sgclaw around zeroclaw compat runtime
This commit is contained in:
2057
third_party/zeroclaw/src/providers/anthropic.rs
vendored
Normal file
2057
third_party/zeroclaw/src/providers/anthropic.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
759
third_party/zeroclaw/src/providers/azure_openai.rs
vendored
Normal file
759
third_party/zeroclaw/src/providers/azure_openai.rs
vendored
Normal file
@@ -0,0 +1,759 @@
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// API version sent when the caller does not supply one.
const DEFAULT_API_VERSION: &str = "2024-08-01-preview";

/// Provider for Azure OpenAI chat-completion deployments.
///
/// Azure addresses the model via resource + deployment name in the URL,
/// so request bodies carry no `model` field (see the serialization tests).
pub struct AzureOpenAiProvider {
    /// API key; `None` makes every chat call fail with a configuration error.
    credential: Option<String>,
    /// Azure resource name (the `<name>` in `<name>.openai.azure.com`).
    resource_name: String,
    /// Deployment name (path segment under `/openai/deployments/`).
    deployment_name: String,
    /// Value of the `api-version` query parameter.
    api_version: String,
    /// Precomputed `https://{resource}.openai.azure.com/openai/deployments/{deployment}`.
    base_url: String,
}
|
||||
|
||||
/// Request body for the simple (no-tools) chat-completions call.
#[derive(Debug, Serialize)]
struct ChatRequest {
    messages: Vec<Message>,
    temperature: f64,
}

/// A single role/content turn in the simple request.
#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: String,
}
|
||||
|
||||
/// Response body for the simple chat-completions call.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

/// One completion choice in the simple response.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}
|
||||
|
||||
/// Assistant message in the simple response; either field may be absent.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    // Some deployments return output under `reasoning_content` instead of
    // `content`; `effective_content` folds the two together.
    #[serde(default)]
    reasoning_content: Option<String>,
}
|
||||
|
||||
impl ResponseMessage {
|
||||
fn effective_content(&self) -> String {
|
||||
match &self.content {
|
||||
Some(c) if !c.is_empty() => c.clone(),
|
||||
_ => self.reasoning_content.clone().unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Request body for the native (tool-calling) chat-completions call.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    messages: Vec<NativeMessage>,
    temperature: f64,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
    // "auto" whenever tools are present; omitted otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}

/// Outgoing message in the native request; unset fields are omitted from JSON.
#[derive(Debug, Serialize)]
struct NativeMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    // Set only on `tool`-role messages echoing a prior call id.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    // Set only on assistant messages that carry tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    reasoning_content: Option<String>,
}

/// OpenAI-style tool specification: `{"type":"function","function":{...}}`.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolSpec {
    #[serde(rename = "type")]
    kind: String,
    function: NativeToolFunctionSpec,
}

/// Function payload of a tool specification.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolFunctionSpec {
    name: String,
    description: String,
    // JSON Schema describing the function's parameters.
    parameters: serde_json::Value,
}
|
||||
|
||||
/// Deserialize a raw JSON tool value into a [`NativeToolSpec`], rejecting
/// any tool whose `type` is not `"function"`.
fn parse_native_tool_spec(value: serde_json::Value) -> anyhow::Result<NativeToolSpec> {
    let spec: NativeToolSpec = serde_json::from_value(value)
        .map_err(|e| anyhow::anyhow!("Invalid Azure OpenAI tool specification: {e}"))?;

    match spec.kind.as_str() {
        "function" => Ok(spec),
        other => anyhow::bail!(
            "Invalid Azure OpenAI tool specification: unsupported tool type '{}', expected 'function'",
            other
        ),
    }
}
|
||||
|
||||
/// A tool call emitted by the model (or echoed back in history).
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    // May be absent in some responses; a UUID is generated as a fallback.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}

/// Function name plus raw JSON-string arguments of a tool call.
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    arguments: String,
}

/// Response body for the native chat-completions call.
#[derive(Debug, Deserialize)]
struct NativeChatResponse {
    choices: Vec<NativeChoice>,
    #[serde(default)]
    usage: Option<UsageInfo>,
}

/// Token accounting reported by the service.
#[derive(Debug, Deserialize)]
struct UsageInfo {
    #[serde(default)]
    prompt_tokens: Option<u64>,
    #[serde(default)]
    completion_tokens: Option<u64>,
}

/// One completion choice in the native response.
#[derive(Debug, Deserialize)]
struct NativeChoice {
    message: NativeResponseMessage,
}

/// Assistant message in the native response; all fields optional.
#[derive(Debug, Deserialize)]
struct NativeResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    reasoning_content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
|
||||
|
||||
impl NativeResponseMessage {
|
||||
fn effective_content(&self) -> Option<String> {
|
||||
match &self.content {
|
||||
Some(c) if !c.is_empty() => Some(c.clone()),
|
||||
_ => self.reasoning_content.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AzureOpenAiProvider {
    /// Build a provider for one Azure deployment.
    ///
    /// `credential` is the API key (may be absent; requests then fail with a
    /// configuration error). `api_version` defaults to
    /// [`DEFAULT_API_VERSION`] when `None`.
    pub fn new(
        credential: Option<&str>,
        resource_name: &str,
        deployment_name: &str,
        api_version: Option<&str>,
    ) -> Self {
        let version = api_version.unwrap_or(DEFAULT_API_VERSION);
        let base_url = format!(
            "https://{}.openai.azure.com/openai/deployments/{}",
            resource_name, deployment_name
        );
        Self {
            credential: credential.map(ToString::to_string),
            resource_name: resource_name.to_string(),
            deployment_name: deployment_name.to_string(),
            api_version: version.to_string(),
            base_url,
        }
    }

    /// Full chat-completions endpoint URL including the `api-version` query.
    fn chat_completions_url(&self) -> String {
        format!(
            "{}/chat/completions?api-version={}",
            self.base_url, self.api_version
        )
    }

    /// Map internal [`ToolSpec`]s to the wire-format tool specification.
    /// `None` in, `None` out — no tools means no `tools` field is serialized.
    fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
        tools.map(|items| {
            items
                .iter()
                .map(|tool| NativeToolSpec {
                    kind: "function".to_string(),
                    function: NativeToolFunctionSpec {
                        name: tool.name.clone(),
                        description: tool.description.clone(),
                        parameters: tool.parameters.clone(),
                    },
                })
                .collect()
        })
    }

    /// Map provider-agnostic history messages to the native wire format.
    ///
    /// Assistant and tool turns may carry structured payloads serialized as
    /// JSON inside `content` (tool_calls / tool_call_id); those are unpacked
    /// into the dedicated fields. Anything that does not parse falls through
    /// to a plain role/content message.
    fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
        messages
            .iter()
            .map(|m| {
                if m.role == "assistant" {
                    // An assistant turn that previously issued tool calls is
                    // stored as JSON: {"content":..., "tool_calls":[...], ...}.
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
                        if let Some(tool_calls_value) = value.get("tool_calls") {
                            if let Ok(parsed_calls) =
                                serde_json::from_value::<Vec<ProviderToolCall>>(
                                    tool_calls_value.clone(),
                                )
                            {
                                let tool_calls = parsed_calls
                                    .into_iter()
                                    .map(|tc| NativeToolCall {
                                        id: Some(tc.id),
                                        kind: Some("function".to_string()),
                                        function: NativeFunctionCall {
                                            name: tc.name,
                                            arguments: tc.arguments,
                                        },
                                    })
                                    .collect::<Vec<_>>();
                                let content = value
                                    .get("content")
                                    .and_then(serde_json::Value::as_str)
                                    .map(ToString::to_string);
                                let reasoning_content = value
                                    .get("reasoning_content")
                                    .and_then(serde_json::Value::as_str)
                                    .map(ToString::to_string);
                                return NativeMessage {
                                    role: "assistant".to_string(),
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Some(tool_calls),
                                    reasoning_content,
                                };
                            }
                        }
                    }
                }

                if m.role == "tool" {
                    // Tool results are stored as JSON:
                    // {"tool_call_id":..., "content":...}.
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
                        let tool_call_id = value
                            .get("tool_call_id")
                            .and_then(serde_json::Value::as_str)
                            .map(ToString::to_string);
                        let content = value
                            .get("content")
                            .and_then(serde_json::Value::as_str)
                            .map(ToString::to_string);
                        return NativeMessage {
                            role: "tool".to_string(),
                            content,
                            tool_call_id,
                            tool_calls: None,
                            reasoning_content: None,
                        };
                    }
                }

                // Default: plain role/content message (system, user, or any
                // assistant/tool content that was not structured JSON).
                NativeMessage {
                    role: m.role.clone(),
                    content: Some(m.content.clone()),
                    tool_call_id: None,
                    tool_calls: None,
                    reasoning_content: None,
                }
            })
            .collect()
    }

    /// Convert a native response message into the provider-agnostic shape.
    ///
    /// Tool calls without an id get a freshly generated UUID so downstream
    /// bookkeeping always has one. `usage` is filled in by the caller.
    fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
        let text = message.effective_content();
        let reasoning_content = message.reasoning_content.clone();
        let tool_calls = message
            .tool_calls
            .unwrap_or_default()
            .into_iter()
            .map(|tc| ProviderToolCall {
                id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
                name: tc.function.name,
                arguments: tc.function.arguments,
            })
            .collect::<Vec<_>>();

        ProviderChatResponse {
            text,
            tool_calls,
            usage: None,
            reasoning_content,
        }
    }

    /// Proxy-aware HTTP client for this provider.
    // NOTE(review): the 120/10 arguments are presumably request/connect
    // timeouts in seconds — confirm against the builder's signature.
    fn http_client(&self) -> Client {
        crate::config::build_runtime_proxy_client_with_timeouts("provider.azure_openai", 120, 10)
    }
}
|
||||
|
||||
#[async_trait]
impl Provider for AzureOpenAiProvider {
    /// Azure OpenAI supports native tool calling and vision; prompt caching
    /// is not advertised here.
    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            native_tool_calling: true,
            vision: true,
            prompt_caching: false,
        }
    }

    /// Convert tool specs to OpenAI-style JSON payloads
    /// (`{"type":"function","function":{...}}`).
    fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
        ToolsPayload::OpenAI {
            tools: tools
                .iter()
                .map(|tool| {
                    serde_json::json!({
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.parameters,
                        }
                    })
                })
                .collect(),
        }
    }

    fn supports_native_tools(&self) -> bool {
        true
    }

    fn supports_vision(&self) -> bool {
        true
    }

    /// Simple text chat: optional system prompt + single user message.
    ///
    /// `_model` is ignored — the deployment baked into the URL selects the
    /// model. Errors if no API key is configured or the HTTP call fails.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        _model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let credential = self.credential.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
            )
        })?;

        let mut messages = Vec::new();

        if let Some(sys) = system_prompt {
            messages.push(Message {
                role: "system".to_string(),
                content: sys.to_string(),
            });
        }

        messages.push(Message {
            role: "user".to_string(),
            content: message.to_string(),
        });

        let request = ChatRequest {
            messages,
            temperature,
        };

        // Azure authenticates via the `api-key` header, not `Authorization: Bearer`.
        let response = self
            .http_client()
            .post(self.chat_completions_url())
            .header("api-key", credential.as_str())
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(super::api_error("Azure OpenAI", response).await);
        }

        let chat_response: ChatResponse = response.json().await?;

        chat_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message.effective_content())
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
    }

    /// Structured chat with optional tools taken from the request.
    ///
    /// Tool choice is forced to "auto" whenever tools are present. Token
    /// usage from the response is attached to the result.
    async fn chat(
        &self,
        request: ProviderChatRequest<'_>,
        _model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        let credential = self.credential.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
            )
        })?;

        let tools = Self::convert_tools(request.tools);
        let native_request = NativeChatRequest {
            messages: Self::convert_messages(request.messages),
            temperature,
            tool_choice: tools.as_ref().map(|_| "auto".to_string()),
            tools,
        };

        let response = self
            .http_client()
            .post(self.chat_completions_url())
            .header("api-key", credential.as_str())
            .json(&native_request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(super::api_error("Azure OpenAI", response).await);
        }

        let native_response: NativeChatResponse = response.json().await?;
        let usage = native_response.usage.map(|u| TokenUsage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
            cached_input_tokens: None,
        });
        let message = native_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message)
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
        let mut result = Self::parse_native_response(message);
        result.usage = usage;
        Ok(result)
    }

    /// Chat with caller-supplied raw JSON tool specs.
    ///
    /// Each tool value must be an OpenAI-style `"function"` spec; an invalid
    /// spec aborts the whole call before any request is sent.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[serde_json::Value],
        _model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        let credential = self.credential.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
            )
        })?;

        // Empty tool list means "no tools field at all", not an empty array.
        let native_tools: Option<Vec<NativeToolSpec>> = if tools.is_empty() {
            None
        } else {
            Some(
                tools
                    .iter()
                    .cloned()
                    .map(parse_native_tool_spec)
                    .collect::<Result<Vec<_>, _>>()?,
            )
        };

        let native_request = NativeChatRequest {
            messages: Self::convert_messages(messages),
            temperature,
            tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
            tools: native_tools,
        };

        let response = self
            .http_client()
            .post(self.chat_completions_url())
            .header("api-key", credential.as_str())
            .json(&native_request)
            .send()
            .await?;

        if !response.status().is_success() {
            return Err(super::api_error("Azure OpenAI", response).await);
        }

        let native_response: NativeChatResponse = response.json().await?;
        let usage = native_response.usage.map(|u| TokenUsage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
            cached_input_tokens: None,
        });
        let message = native_response
            .choices
            .into_iter()
            .next()
            .map(|c| c.message)
            .ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
        let mut result = Self::parse_native_response(message);
        result.usage = usage;
        Ok(result)
    }

    async fn warmup(&self) -> anyhow::Result<()> {
        // Azure OpenAI does not have a lightweight models endpoint,
        // so warmup is a no-op to avoid unnecessary API calls.
        Ok(())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // --- URL construction ---

    #[test]
    fn url_construction_default_version() {
        let p = AzureOpenAiProvider::new(Some("test-key"), "my-resource", "gpt-4o", None);
        assert_eq!(
            p.chat_completions_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
        );
    }

    #[test]
    fn url_construction_custom_version() {
        let p = AzureOpenAiProvider::new(
            Some("test-key"),
            "my-resource",
            "gpt-4o",
            Some("2024-06-01"),
        );
        assert_eq!(
            p.chat_completions_url(),
            "https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-06-01"
        );
    }

    #[test]
    fn url_construction_preserves_resource_and_deployment() {
        let p = AzureOpenAiProvider::new(Some("key"), "contoso-ai", "my-gpt35-deployment", None);
        let url = p.chat_completions_url();
        assert!(url.contains("contoso-ai.openai.azure.com"));
        assert!(url.contains("/deployments/my-gpt35-deployment/"));
        assert!(url.contains("api-version=2024-08-01-preview"));
    }

    // --- Construction / credential storage ---

    #[test]
    fn auth_header_uses_api_key_not_bearer() {
        // This test verifies the provider stores the credential correctly
        // and that the auth header name is "api-key" (verified via the
        // implementation in chat_with_system which uses .header("api-key", ...)).
        let p = AzureOpenAiProvider::new(Some("my-azure-key"), "resource", "deployment", None);
        assert_eq!(p.credential.as_deref(), Some("my-azure-key"));
    }

    #[test]
    fn creates_with_credential() {
        let p = AzureOpenAiProvider::new(
            Some("azure-test-credential"),
            "resource",
            "deployment",
            None,
        );
        assert_eq!(p.credential.as_deref(), Some("azure-test-credential"));
        assert_eq!(p.resource_name, "resource");
        assert_eq!(p.deployment_name, "deployment");
        assert_eq!(p.api_version, DEFAULT_API_VERSION);
    }

    #[test]
    fn creates_without_credential() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        assert!(p.credential.is_none());
    }

    // --- Missing-key error paths (no network traffic occurs) ---

    #[tokio::test]
    async fn chat_fails_without_key() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let result = p.chat_with_system(None, "hello", "gpt-4o", 0.7).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key not set"));
    }

    #[tokio::test]
    async fn chat_with_system_fails_without_key() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let result = p
            .chat_with_system(Some("You are ZeroClaw"), "test", "gpt-4o", 0.5)
            .await;
        assert!(result.is_err());
    }

    // --- Request/response (de)serialization ---

    #[test]
    fn request_serializes_with_system_message() {
        let req = ChatRequest {
            messages: vec![
                Message {
                    role: "system".to_string(),
                    content: "You are ZeroClaw".to_string(),
                },
                Message {
                    role: "user".to_string(),
                    content: "hello".to_string(),
                },
            ],
            temperature: 0.7,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"role\":\"system\""));
        assert!(json.contains("\"role\":\"user\""));
        // Azure requests should NOT contain a model field (deployment is in the URL)
        assert!(!json.contains("\"model\""));
    }

    #[test]
    fn request_serializes_without_system() {
        let req = ChatRequest {
            messages: vec![Message {
                role: "user".to_string(),
                content: "hello".to_string(),
            }],
            temperature: 0.0,
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("system"));
        assert!(json.contains("\"temperature\":0.0"));
    }

    #[test]
    fn response_deserializes_single_choice() {
        let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.choices.len(), 1);
        assert_eq!(resp.choices[0].message.effective_content(), "Hi!");
    }

    #[test]
    fn response_deserializes_empty_choices() {
        let json = r#"{"choices":[]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert!(resp.choices.is_empty());
    }

    #[test]
    fn response_deserializes_multiple_choices() {
        let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.choices.len(), 2);
        assert_eq!(resp.choices[0].message.effective_content(), "A");
    }

    // --- Tool-call parsing ---

    #[test]
    fn tool_call_response_parsing() {
        let json = r#"{"choices":[{"message":{
            "content":"Let me check",
            "tool_calls":[{
                "id":"call_abc123",
                "type":"function",
                "function":{"name":"shell","arguments":"{\"command\":\"ls\"}"}
            }]
        }}],"usage":{"prompt_tokens":50,"completion_tokens":25}}"#;
        let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
        let message = resp.choices.into_iter().next().unwrap().message;
        let parsed = AzureOpenAiProvider::parse_native_response(message);
        assert_eq!(parsed.text.as_deref(), Some("Let me check"));
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].id, "call_abc123");
        assert_eq!(parsed.tool_calls[0].name, "shell");
        assert!(parsed.tool_calls[0].arguments.contains("ls"));
    }

    #[test]
    fn tool_call_response_without_id_generates_uuid() {
        let json = r#"{"choices":[{"message":{
            "content":null,
            "tool_calls":[{
                "function":{"name":"test","arguments":"{}"}
            }]
        }}]}"#;
        let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
        let message = resp.choices.into_iter().next().unwrap().message;
        let parsed = AzureOpenAiProvider::parse_native_response(message);
        assert_eq!(parsed.tool_calls.len(), 1);
        assert!(!parsed.tool_calls[0].id.is_empty());
    }

    #[tokio::test]
    async fn chat_with_tools_fails_without_key() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let messages = vec![ChatMessage::user("hello".to_string())];
        let tools = vec![serde_json::json!({
            "type": "function",
            "function": {
                "name": "shell",
                "description": "Run a shell command",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "command": { "type": "string" }
                    },
                    "required": ["command"]
                }
            }
        })];
        let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key not set"));
    }

    #[test]
    fn native_response_parses_usage() {
        let json = r#"{
            "choices": [{"message": {"content": "Hello"}}],
            "usage": {"prompt_tokens": 100, "completion_tokens": 50}
        }"#;
        let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
        let usage = resp.usage.unwrap();
        assert_eq!(usage.prompt_tokens, Some(100));
        assert_eq!(usage.completion_tokens, Some(50));
    }

    // --- Capability surface ---

    #[test]
    fn capabilities_reports_native_tools_and_vision() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
        let caps = <AzureOpenAiProvider as Provider>::capabilities(&p);
        assert!(caps.native_tool_calling);
        assert!(caps.vision);
    }

    #[test]
    fn supports_native_tools_returns_true() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
        assert!(p.supports_native_tools());
    }

    #[test]
    fn supports_vision_returns_true() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
        assert!(p.supports_vision());
    }

    #[tokio::test]
    async fn warmup_is_noop() {
        let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
        let result = p.warmup().await;
        assert!(result.is_ok());
    }

    #[test]
    fn custom_api_version_stored() {
        let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", Some("2025-01-01"));
        assert_eq!(p.api_version, "2025-01-01");
    }
}
|
||||
1847
third_party/zeroclaw/src/providers/bedrock.rs
vendored
Normal file
1847
third_party/zeroclaw/src/providers/bedrock.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
523
third_party/zeroclaw/src/providers/claude_code.rs
vendored
Normal file
523
third_party/zeroclaw/src/providers/claude_code.rs
vendored
Normal file
@@ -0,0 +1,523 @@
|
||||
//! Claude Code headless CLI provider.
|
||||
//!
|
||||
//! Integrates with the Claude Code CLI, spawning the `claude` binary
|
||||
//! as a subprocess for each inference request. This allows using Claude's AI
|
||||
//! models without an interactive UI session.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The `claude` binary must be available in `PATH`, or its location must be
|
||||
//! set via the `CLAUDE_CODE_PATH` environment variable.
|
||||
//!
|
||||
//! Claude Code is invoked as:
|
||||
//! ```text
|
||||
//! claude --print -
|
||||
//! ```
|
||||
//! with prompt content written to stdin.
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
//! Only default values are accepted; custom values return an explicit error.
|
||||
//!
|
||||
//! # Authentication
|
||||
//!
|
||||
//! Authentication is handled by Claude Code itself (its own credential store).
|
||||
//! No explicit API key is required by this provider.
|
||||
//!
|
||||
//! # Environment variables
|
||||
//!
|
||||
//! - `CLAUDE_CODE_PATH` — override the path to the `claude` binary (default: `"claude"`)
|
||||
|
||||
use crate::providers::traits::{ChatMessage, ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Environment variable for overriding the path to the `claude` binary.
pub const CLAUDE_CODE_PATH_ENV: &str = "CLAUDE_CODE_PATH";

/// Default `claude` binary name (resolved via `PATH`).
const DEFAULT_CLAUDE_CODE_BINARY: &str = "claude";

/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// Claude Code requests are bounded to avoid hung subprocesses.
const CLAUDE_CODE_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_CLAUDE_CODE_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const CLAUDE_CODE_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
/// Tolerance for floating-point comparison against the supported temperatures.
const TEMP_EPSILON: f64 = 1e-9;
|
||||
|
||||
/// Provider that invokes the Claude Code CLI as a subprocess.
///
/// Each inference request spawns a fresh `claude` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct ClaudeCodeProvider {
    /// Path to the `claude` binary (from `CLAUDE_CODE_PATH` or `PATH` lookup).
    binary_path: PathBuf,
}
|
||||
|
||||
impl ClaudeCodeProvider {
    /// Create a new `ClaudeCodeProvider`.
    ///
    /// The binary path is resolved from `CLAUDE_CODE_PATH` env var if set,
    /// otherwise defaults to `"claude"` (found via `PATH`). A set-but-blank
    /// variable is treated as unset.
    pub fn new() -> Self {
        let binary_path = std::env::var(CLAUDE_CODE_PATH_ENV)
            .ok()
            .filter(|path| !path.trim().is_empty())
            .map(PathBuf::from)
            .unwrap_or_else(|| PathBuf::from(DEFAULT_CLAUDE_CODE_BINARY));

        Self { binary_path }
    }

    /// Returns true if the model argument should be forwarded to the CLI.
    /// Blank or `"default"` means "let the CLI pick its own model".
    fn should_forward_model(model: &str) -> bool {
        let trimmed = model.trim();
        !trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
    }

    /// True when `temperature` matches one of the supported baseline values
    /// within [`TEMP_EPSILON`].
    fn supports_temperature(temperature: f64) -> bool {
        CLAUDE_CODE_SUPPORTED_TEMPERATURES
            .iter()
            .any(|v| (temperature - v).abs() < TEMP_EPSILON)
    }

    /// Validate a requested temperature, clamping unsupported finite values
    /// to the nearest supported one. Errors only on non-finite input.
    fn validate_temperature(temperature: f64) -> anyhow::Result<f64> {
        if !temperature.is_finite() {
            anyhow::bail!("Claude Code provider received non-finite temperature value");
        }
        if Self::supports_temperature(temperature) {
            return Ok(temperature);
        }
        // Clamp to the nearest supported value — the CLI ignores temperature
        // anyway, so a hard error just blocks callers like memory consolidation
        // that legitimately request low temperatures.
        let clamped = *CLAUDE_CODE_SUPPORTED_TEMPERATURES
            .iter()
            .min_by(|a, b| {
                (temperature - **a)
                    .abs()
                    .partial_cmp(&(temperature - **b).abs())
                    .unwrap()
            })
            .unwrap();
        tracing::debug!(
            requested = temperature,
            clamped = clamped,
            "Clamped unsupported temperature to nearest Claude Code CLI value"
        );
        Ok(clamped)
    }

    /// Trim and cap stderr output to [`MAX_CLAUDE_CODE_STDERR_CHARS`]
    /// characters (char-count, not bytes, so UTF-8 is never split).
    fn redact_stderr(stderr: &[u8]) -> String {
        let text = String::from_utf8_lossy(stderr);
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return String::new();
        }
        if trimmed.chars().count() <= MAX_CLAUDE_CODE_STDERR_CHARS {
            return trimmed.to_string();
        }
        let clipped: String = trimmed.chars().take(MAX_CLAUDE_CODE_STDERR_CHARS).collect();
        format!("{clipped}...")
    }

    /// Invoke the claude binary with the given prompt and optional model.
    /// Returns the trimmed stdout output as the assistant response.
    ///
    /// Errors when the binary cannot be spawned, stdin cannot be written,
    /// the request exceeds [`CLAUDE_CODE_REQUEST_TIMEOUT`], the process
    /// exits non-zero, or stdout is not valid UTF-8.
    async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
        let mut cmd = Command::new(&self.binary_path);
        cmd.arg("--print");

        if Self::should_forward_model(model) {
            cmd.arg("--model").arg(model);
        }

        // Read prompt from stdin to avoid exposing sensitive content in process args.
        cmd.arg("-");
        // Ensure the child is killed if this future is dropped (e.g. on timeout).
        cmd.kill_on_drop(true);
        cmd.stdin(std::process::Stdio::piped());
        cmd.stdout(std::process::Stdio::piped());
        cmd.stderr(std::process::Stdio::piped());

        let mut child = cmd.spawn().map_err(|err| {
            anyhow::anyhow!(
                "Failed to spawn Claude Code binary at {}: {err}. \
                Ensure `claude` is installed and in PATH, or set CLAUDE_CODE_PATH.",
                self.binary_path.display()
            )
        })?;

        // Write the prompt, then close stdin so the CLI sees EOF and starts.
        if let Some(mut stdin) = child.stdin.take() {
            stdin.write_all(message.as_bytes()).await.map_err(|err| {
                anyhow::anyhow!("Failed to write prompt to Claude Code stdin: {err}")
            })?;
            stdin.shutdown().await.map_err(|err| {
                anyhow::anyhow!("Failed to finalize Claude Code stdin stream: {err}")
            })?;
        }

        let output = timeout(CLAUDE_CODE_REQUEST_TIMEOUT, child.wait_with_output())
            .await
            .map_err(|_| {
                anyhow::anyhow!(
                    "Claude Code request timed out after {:?} (binary: {})",
                    CLAUDE_CODE_REQUEST_TIMEOUT,
                    self.binary_path.display()
                )
            })?
            .map_err(|err| anyhow::anyhow!("Claude Code process failed: {err}"))?;

        if !output.status.success() {
            let code = output.status.code().unwrap_or(-1);
            let stderr_excerpt = Self::redact_stderr(&output.stderr);
            let stderr_note = if stderr_excerpt.is_empty() {
                String::new()
            } else {
                format!(" Stderr: {stderr_excerpt}")
            };
            anyhow::bail!(
                "Claude Code exited with non-zero status {code}. \
                Check that Claude Code is authenticated and the CLI is supported.{stderr_note}"
            );
        }

        let text = String::from_utf8(output.stdout)
            .map_err(|err| anyhow::anyhow!("Claude Code produced non-UTF-8 output: {err}"))?;

        Ok(text.trim().to_string())
    }
}
|
||||
|
||||
impl Default for ClaudeCodeProvider {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for ClaudeCodeProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
let full_message = match system_prompt {
|
||||
Some(system) if !system.is_empty() => {
|
||||
format!("{system}\n\n{message}")
|
||||
}
|
||||
_ => message.to_string(),
|
||||
};
|
||||
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
// Separate system prompt from conversation messages.
|
||||
let system = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
.map(|m| m.content.as_str());
|
||||
|
||||
// Build conversation turns (skip system messages).
|
||||
let turns: Vec<&ChatMessage> = messages.iter().filter(|m| m.role != "system").collect();
|
||||
|
||||
// If there's only one user message, use the simple path.
|
||||
if turns.len() <= 1 {
|
||||
let last_user = turns.first().map(|m| m.content.as_str()).unwrap_or("");
|
||||
let full_message = match system {
|
||||
Some(s) if !s.is_empty() => format!("{s}\n\n{last_user}"),
|
||||
_ => last_user.to_string(),
|
||||
};
|
||||
return self.invoke_cli(&full_message, model).await;
|
||||
}
|
||||
|
||||
// Format multi-turn conversation into a single prompt.
|
||||
let mut parts = Vec::new();
|
||||
if let Some(s) = system {
|
||||
if !s.is_empty() {
|
||||
parts.push(format!("[system]\n{s}"));
|
||||
}
|
||||
}
|
||||
for msg in &turns {
|
||||
let label = match msg.role.as_str() {
|
||||
"user" => "[user]",
|
||||
"assistant" => "[assistant]",
|
||||
other => other,
|
||||
};
|
||||
parts.push(format!("{label}\n{}", msg.content));
|
||||
}
|
||||
parts.push("[assistant]".to_string());
|
||||
|
||||
let full_message = parts.join("\n\n");
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
||||
Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(TokenUsage::default()),
|
||||
reasoning_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Mutex, OnceLock};

    /// Global lock serializing tests that mutate process environment
    /// variables (`set_var`/`remove_var` are process-wide).
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .expect("env lock poisoned")
    }

    /// Serialize tests that spawn the echo-provider script.
    ///
    /// On Linux, writing a shell script and exec'ing it from parallel threads
    /// can trigger `ETXTBSY` ("Text file busy") even with unique file paths,
    /// because the kernel briefly holds `deny_write_access` on the interpreter
    /// page cache. Serializing these tests eliminates the race.
    ///
    /// Uses `tokio::sync::Mutex` so the guard can be held across `.await`.
    fn script_mutex() -> &'static tokio::sync::Mutex<()> {
        static LOCK: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| tokio::sync::Mutex::new(()))
    }

    /// An explicit CLAUDE_CODE_PATH override must win over the default.
    #[test]
    fn new_uses_env_override() {
        let _guard = env_lock();
        let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
        std::env::set_var(CLAUDE_CODE_PATH_ENV, "/usr/local/bin/claude");
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/claude"));
        // Restore the pre-test value so later tests see a clean environment.
        match orig {
            Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
            None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
        }
    }

    /// Without an override the provider falls back to `claude` on PATH.
    #[test]
    fn new_defaults_to_claude() {
        let _guard = env_lock();
        let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
        std::env::remove_var(CLAUDE_CODE_PATH_ENV);
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("claude"));
        // Only a previously-set value needs restoring; the var is already unset.
        if let Some(v) = orig {
            std::env::set_var(CLAUDE_CODE_PATH_ENV, v);
        }
    }

    /// A whitespace-only override is treated the same as no override.
    #[test]
    fn new_ignores_blank_env_override() {
        let _guard = env_lock();
        let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
        std::env::set_var(CLAUDE_CODE_PATH_ENV, "   ");
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("claude"));
        match orig {
            Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
            None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
        }
    }

    /// Concrete model names are forwarded to the CLI via `--model`.
    #[test]
    fn should_forward_model_standard() {
        assert!(ClaudeCodeProvider::should_forward_model(
            "claude-sonnet-4-20250514"
        ));
        assert!(ClaudeCodeProvider::should_forward_model(
            "claude-3.5-sonnet"
        ));
    }

    /// The sentinel/blank model values are NOT forwarded; the CLI picks
    /// its own default in that case.
    #[test]
    fn should_not_forward_default_model() {
        assert!(!ClaudeCodeProvider::should_forward_model(
            DEFAULT_MODEL_MARKER
        ));
        assert!(!ClaudeCodeProvider::should_forward_model(""));
        assert!(!ClaudeCodeProvider::should_forward_model("   "));
    }

    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(ClaudeCodeProvider::validate_temperature(0.7).is_ok());
        assert!(ClaudeCodeProvider::validate_temperature(1.0).is_ok());
    }

    /// Out-of-range finite temperatures are clamped to the nearest
    /// supported value rather than rejected.
    #[test]
    fn validate_temperature_clamps_custom_value() {
        let clamped = ClaudeCodeProvider::validate_temperature(0.2).unwrap();
        assert!((clamped - 0.7).abs() < 1e-9, "0.2 should clamp to 0.7");

        let clamped = ClaudeCodeProvider::validate_temperature(0.9).unwrap();
        assert!((clamped - 1.0).abs() < 1e-9, "0.9 should clamp to 1.0");
    }

    /// NaN / infinity cannot be clamped meaningfully and must error.
    #[test]
    fn validate_temperature_rejects_non_finite() {
        assert!(ClaudeCodeProvider::validate_temperature(f64::NAN).is_err());
        assert!(ClaudeCodeProvider::validate_temperature(f64::INFINITY).is_err());
    }

    /// Spawning a nonexistent binary must surface the spawn error message.
    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        let provider = ClaudeCodeProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/claude"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn Claude Code binary"),
            "unexpected error message: {msg}"
        );
    }

    /// Helper: create a provider that uses a shell script echoing stdin back.
    /// The script ignores CLI flags (`--print`, `--model`, `-`) and just cats stdin.
    ///
    /// Each invocation places the script in its own unique directory and writes
    /// the file atomically via `std::fs::write` to avoid `ETXTBSY` ("Text file
    /// busy") races that occur when parallel test threads create and exec
    /// scripts concurrently on the same filesystem.
    fn echo_provider() -> ClaudeCodeProvider {
        // Monotonic counter + PID keeps each script path unique per process.
        static SCRIPT_ID: AtomicUsize = AtomicUsize::new(0);
        let script_id = SCRIPT_ID.fetch_add(1, Ordering::Relaxed);
        let dir = std::env::temp_dir().join(format!(
            "zeroclaw_test_claude_code_{}_{}",
            std::process::id(),
            script_id
        ));
        std::fs::create_dir_all(&dir).unwrap();

        let path = dir.join("fake_claude.sh");
        std::fs::write(&path, "#!/bin/sh\ncat /dev/stdin\n").unwrap();
        // The script must be executable on Unix; mode bits don't exist elsewhere.
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap();
        }
        ClaudeCodeProvider { binary_path: path }
    }

    /// Sanity check for the ETXTBSY mitigation: paths never collide.
    #[test]
    fn echo_provider_uses_unique_script_paths() {
        let first = echo_provider();
        let second = echo_provider();
        assert_ne!(first.binary_path, second.binary_path);
    }

    /// One user turn, no system prompt: message passes through verbatim.
    #[tokio::test]
    async fn chat_with_history_single_user_message() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![ChatMessage::user("hello")];
        let result = provider
            .chat_with_history(&messages, "default", 1.0)
            .await
            .unwrap();
        assert_eq!(result, "hello");
    }

    /// One user turn + system prompt: system text is prepended with a blank line.
    #[tokio::test]
    async fn chat_with_history_single_user_with_system() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![
            ChatMessage::system("You are helpful."),
            ChatMessage::user("hello"),
        ];
        let result = provider
            .chat_with_history(&messages, "default", 1.0)
            .await
            .unwrap();
        assert_eq!(result, "You are helpful.\n\nhello");
    }

    /// Multi-turn conversations are flattened with role labels and end with
    /// a trailing "[assistant]" continuation cue.
    #[tokio::test]
    async fn chat_with_history_multi_turn_includes_all_messages() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![
            ChatMessage::system("Be concise."),
            ChatMessage::user("What is 2+2?"),
            ChatMessage::assistant("4"),
            ChatMessage::user("And 3+3?"),
        ];
        let result = provider
            .chat_with_history(&messages, "default", 1.0)
            .await
            .unwrap();
        assert!(result.contains("[system]\nBe concise."));
        assert!(result.contains("[user]\nWhat is 2+2?"));
        assert!(result.contains("[assistant]\n4"));
        assert!(result.contains("[user]\nAnd 3+3?"));
        assert!(result.ends_with("[assistant]"));
    }

    /// Without a system message no "[system]" section is emitted.
    #[tokio::test]
    async fn chat_with_history_multi_turn_without_system() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![
            ChatMessage::user("hi"),
            ChatMessage::assistant("hello"),
            ChatMessage::user("bye"),
        ];
        let result = provider
            .chat_with_history(&messages, "default", 1.0)
            .await
            .unwrap();
        assert!(!result.contains("[system]"));
        assert!(result.contains("[user]\nhi"));
        assert!(result.contains("[assistant]\nhello"));
        assert!(result.contains("[user]\nbye"));
    }

    /// An unsupported temperature must be clamped, never turned into an error.
    #[tokio::test]
    async fn chat_with_history_clamps_bad_temperature() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![ChatMessage::user("test")];
        let result = provider.chat_with_history(&messages, "default", 0.5).await;
        assert!(
            result.is_ok(),
            "unsupported temperature should be clamped, not rejected"
        );
    }
}
|
||||
3989
third_party/zeroclaw/src/providers/compatible.rs
vendored
Normal file
3989
third_party/zeroclaw/src/providers/compatible.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
822
third_party/zeroclaw/src/providers/copilot.rs
vendored
Normal file
822
third_party/zeroclaw/src/providers/copilot.rs
vendored
Normal file
@@ -0,0 +1,822 @@
|
||||
//! GitHub Copilot provider with OAuth device-flow authentication.
|
||||
//!
|
||||
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
|
||||
//! then exchanges the OAuth token for short-lived Copilot API keys.
|
||||
//! Tokens are cached to disk and auto-refreshed.
|
||||
//!
|
||||
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
|
||||
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
|
||||
//! and other third-party Copilot integrations. The Copilot token endpoint is
|
||||
//! private; there is no public OAuth scope or app registration for it.
|
||||
//! GitHub could change or revoke this at any time, which would break all
|
||||
//! third-party integrations simultaneously.
|
||||
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, TokenUsage, ToolCall as ProviderToolCall,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
/// GitHub OAuth client ID for Copilot (VS Code extension).
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
// Device-code flow: requests a new device/user code pair.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
// Device-code flow: polled until the user authorizes in the browser.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
// Exchanges a GitHub OAuth token for a short-lived Copilot API key.
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
// Fallback Copilot API base URL when the key response names no endpoint.
const DEFAULT_API: &str = "https://api.githubcopilot.com";
|
||||
|
||||
// ── Token types ──────────────────────────────────────────────────
|
||||
|
||||
/// Response from GitHub's device-code endpoint.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    // Opaque code the client polls the access-token endpoint with.
    device_code: String,
    // Short code the user types at the verification URI.
    user_code: String,
    verification_uri: String,
    // Minimum polling interval in seconds; defaults to 5 when absent.
    #[serde(default = "default_interval")]
    interval: u64,
    // Device-code lifetime in seconds; defaults to 900 when absent.
    #[serde(default = "default_expires_in")]
    expires_in: u64,
}
|
||||
|
||||
/// Serde fallback: GitHub's default device-flow polling interval (seconds).
fn default_interval() -> u64 {
    5_u64
}
|
||||
|
||||
/// Serde fallback: device-code lifetime in seconds (15 minutes).
fn default_expires_in() -> u64 {
    15 * 60
}
|
||||
|
||||
/// Response from the access-token polling endpoint. Exactly one of
/// `access_token` / `error` is normally present.
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
    access_token: Option<String>,
    // e.g. "authorization_pending", "slow_down", "expired_token".
    error: Option<String>,
}
|
||||
|
||||
/// Short-lived Copilot API key as returned by the token-exchange endpoint
/// (also the on-disk cache format; hence Serialize + Deserialize).
#[derive(Debug, Serialize, Deserialize)]
struct ApiKeyInfo {
    token: String,
    // Unix timestamp (seconds) after which the key is invalid.
    expires_at: i64,
    // Optional per-account endpoint overrides.
    #[serde(default)]
    endpoints: Option<ApiEndpoints>,
}
|
||||
|
||||
/// Endpoint overrides embedded in the API-key response.
#[derive(Debug, Serialize, Deserialize)]
struct ApiEndpoints {
    // Chat-completions base URL; `DEFAULT_API` is used when absent.
    api: Option<String>,
}
|
||||
|
||||
/// In-memory cache entry guarded by `CopilotProvider::refresh_lock`.
struct CachedApiKey {
    token: String,
    // Resolved chat-completions base URL for this key.
    api_endpoint: String,
    // Unix timestamp (seconds) after which the key must be refreshed.
    expires_at: i64,
}
|
||||
|
||||
// ── Chat completions types ───────────────────────────────────────
|
||||
|
||||
/// Outgoing chat-completions request body (OpenAI-compatible wire format).
#[derive(Debug, Serialize)]
struct ApiChatRequest<'a> {
    model: String,
    messages: Vec<ApiMessage>,
    temperature: f64,
    // Omitted entirely when no tools are supplied.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec<'a>>>,
    // Set to "auto" whenever tools are present; omitted otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}
|
||||
|
||||
/// One message on the wire. Optional fields are role-dependent:
/// `tool_call_id` for role "tool", `tool_calls` for assistant messages.
#[derive(Debug, Serialize)]
struct ApiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<ApiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}
|
||||
|
||||
/// OpenAI-style tool declaration; `type` is always "function".
#[derive(Debug, Serialize)]
struct NativeToolSpec<'a> {
    #[serde(rename = "type")]
    kind: &'static str,
    function: NativeToolFunctionSpec<'a>,
}
|
||||
|
||||
/// Function portion of a tool declaration; borrows from the `ToolSpec`
/// it was converted from (see `CopilotProvider::convert_tools`).
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec<'a> {
    name: &'a str,
    description: &'a str,
    // JSON Schema for the function arguments.
    parameters: &'a serde_json::Value,
}
|
||||
|
||||
/// A tool call as it appears on the wire, in both requests (assistant
/// history replay) and responses.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    // May be absent in responses; a UUID is generated as a fallback.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}
|
||||
|
||||
/// Function name plus its arguments as a raw JSON string (OpenAI format
/// leaves arguments unparsed on the wire).
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    arguments: String,
}
|
||||
|
||||
/// Multi-part content for vision messages (OpenAI format).
///
/// `untagged`: serializes as either a bare string or an array of parts,
/// matching the two shapes the API accepts.
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
enum ApiContent {
    // Plain text message body.
    Text(String),
    // Mixed text + image parts (used when `[IMAGE:...]` markers are found).
    Parts(Vec<ContentPart>),
}
|
||||
|
||||
/// One element of a multi-part message; tagged by a "type" field on the wire.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
enum ContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrlDetail },
}
|
||||
|
||||
/// Wrapper object required by the OpenAI image_url part shape.
#[derive(Debug, Clone, Serialize)]
struct ImageUrlDetail {
    // http(s) or data: URL of the image.
    url: String,
}
|
||||
|
||||
/// Top-level chat-completions response; only the first choice is used.
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    choices: Vec<Choice>,
    // Token accounting; tolerated when absent.
    #[serde(default)]
    usage: Option<UsageInfo>,
}
|
||||
|
||||
/// Token usage counters from the API (maps onto `TokenUsage`).
#[derive(Debug, Deserialize)]
struct UsageInfo {
    #[serde(default)]
    prompt_tokens: Option<u64>,
    #[serde(default)]
    completion_tokens: Option<u64>,
}
|
||||
|
||||
/// One completion alternative; wraps the assistant message.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}
|
||||
|
||||
/// Assistant message in a response: text content and/or tool calls.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
|
||||
|
||||
// ── Provider ─────────────────────────────────────────────────────
|
||||
|
||||
/// GitHub Copilot provider with automatic OAuth and token refresh.
///
/// On first use, prompts the user to visit github.com/login/device.
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
/// automatically.
pub struct CopilotProvider {
    // Optional pre-supplied GitHub OAuth token; skips the device flow.
    github_token: Option<String>,
    /// Mutex ensures only one caller refreshes tokens at a time,
    /// preventing duplicate device flow prompts or redundant API calls.
    refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
    // Directory where the access token and API key are cached on disk.
    token_dir: PathBuf,
}
|
||||
|
||||
impl CopilotProvider {
    /// Build a provider, resolving the token cache directory and tightening
    /// its permissions. An empty `github_token` is treated as absent.
    pub fn new(github_token: Option<&str>) -> Self {
        let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
            .map(|dir| dir.config_dir().join("copilot"))
            .unwrap_or_else(|| {
                // Fall back to a user-specific temp directory to avoid
                // shared-directory symlink attacks.
                let user = std::env::var("USER")
                    .or_else(|_| std::env::var("USERNAME"))
                    .unwrap_or_else(|_| "unknown".to_string());
                std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
            });

        // Best-effort setup: a failure here only disables on-disk caching.
        if let Err(err) = std::fs::create_dir_all(&token_dir) {
            warn!(
                "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
                token_dir
            );
        } else {
            #[cfg(unix)]
            {
                use std::os::unix::fs::PermissionsExt;

                // 0700: tokens are secrets; keep the directory owner-only.
                if let Err(err) =
                    std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
                {
                    warn!(
                        "Failed to set Copilot token directory permissions on {:?}: {err}",
                        token_dir
                    );
                }
            }
        }

        Self {
            github_token: github_token
                .filter(|token| !token.is_empty())
                .map(String::from),
            refresh_lock: Arc::new(Mutex::new(None)),
            token_dir,
        }
    }

    // Fresh client per call; timeouts: 120s request, 10s connect.
    fn http_client(&self) -> Client {
        crate::config::build_runtime_proxy_client_with_timeouts("provider.copilot", 120, 10)
    }

    /// Required headers for Copilot API requests (editor identification).
    const COPILOT_HEADERS: [(&str, &str); 4] = [
        ("Editor-Version", "vscode/1.85.1"),
        ("Editor-Plugin-Version", "copilot/1.155.0"),
        ("User-Agent", "GithubCopilot/1.155.0"),
        ("Accept", "application/json"),
    ];

    /// Map internal `ToolSpec`s to the OpenAI-style wire format, borrowing
    /// the name/description/schema rather than cloning them.
    fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec<'_>>> {
        tools.map(|items| {
            items
                .iter()
                .map(|tool| NativeToolSpec {
                    kind: "function",
                    function: NativeToolFunctionSpec {
                        name: &tool.name,
                        description: &tool.description,
                        parameters: &tool.parameters,
                    },
                })
                .collect()
        })
    }

    /// Convert message content to API format, with multi-part support for
    /// user messages containing `[IMAGE:...]` markers.
    fn to_api_content(role: &str, content: &str) -> Option<ApiContent> {
        // Only user messages can carry images; everything else is plain text.
        if role != "user" {
            return Some(ApiContent::Text(content.to_string()));
        }

        let (cleaned_text, image_refs) = crate::multimodal::parse_image_markers(content);
        if image_refs.is_empty() {
            return Some(ApiContent::Text(content.to_string()));
        }

        // Build parts: optional leading text, then one part per image.
        let mut parts = Vec::with_capacity(image_refs.len() + 1);
        let trimmed = cleaned_text.trim();
        if !trimmed.is_empty() {
            parts.push(ContentPart::Text {
                text: trimmed.to_string(),
            });
        }
        for image_ref in image_refs {
            parts.push(ContentPart::ImageUrl {
                image_url: ImageUrlDetail { url: image_ref },
            });
        }

        Some(ApiContent::Parts(parts))
    }

    /// Map internal chat history to wire messages.
    ///
    /// Assistant and tool messages whose content parses as the internal
    /// JSON envelope (with `tool_calls` / `tool_call_id`) are expanded into
    /// native tool-call shapes; everything else passes through as text.
    fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
        messages
            .iter()
            .map(|message| {
                // Assistant turn that recorded tool calls as a JSON envelope.
                if message.role == "assistant" {
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                        if let Some(tool_calls_value) = value.get("tool_calls") {
                            if let Ok(parsed_calls) =
                                serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
                            {
                                let tool_calls = parsed_calls
                                    .into_iter()
                                    .map(|tool_call| NativeToolCall {
                                        id: Some(tool_call.id),
                                        kind: Some("function".to_string()),
                                        function: NativeFunctionCall {
                                            name: tool_call.name,
                                            arguments: tool_call.arguments,
                                        },
                                    })
                                    .collect::<Vec<_>>();

                                let content = value
                                    .get("content")
                                    .and_then(serde_json::Value::as_str)
                                    .map(|s| ApiContent::Text(s.to_string()));

                                return ApiMessage {
                                    role: "assistant".to_string(),
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Some(tool_calls),
                                };
                            }
                        }
                    }
                }

                // Tool-result turn: unwrap the {tool_call_id, content} envelope.
                if message.role == "tool" {
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                        let tool_call_id = value
                            .get("tool_call_id")
                            .and_then(serde_json::Value::as_str)
                            .map(ToString::to_string);
                        let content = value
                            .get("content")
                            .and_then(serde_json::Value::as_str)
                            .map(|s| ApiContent::Text(s.to_string()));

                        return ApiMessage {
                            role: "tool".to_string(),
                            content,
                            tool_call_id,
                            tool_calls: None,
                        };
                    }
                }

                // Plain message (system/user/unparseable): text or image parts.
                ApiMessage {
                    role: message.role.clone(),
                    content: Self::to_api_content(&message.role, &message.content),
                    tool_call_id: None,
                    tool_calls: None,
                }
            })
            .collect()
    }

    /// Send a chat completions request with required Copilot headers.
    async fn send_chat_request(
        &self,
        messages: Vec<ApiMessage>,
        tools: Option<&[ToolSpec]>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        let (token, endpoint) = self.get_api_key().await?;
        let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));

        let native_tools = Self::convert_tools(tools);
        let request = ApiChatRequest {
            model: model.to_string(),
            messages,
            temperature,
            // "auto" only makes sense when tools are actually supplied.
            tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
            tools: native_tools,
        };

        let mut req = self
            .http_client()
            .post(&url)
            .header("Authorization", format!("Bearer {token}"))
            .json(&request);

        for (header, value) in &Self::COPILOT_HEADERS {
            req = req.header(*header, *value);
        }

        let response = req.send().await?;

        if !response.status().is_success() {
            return Err(super::api_error("GitHub Copilot", response).await);
        }

        let api_response: ApiChatResponse = response.json().await?;
        let usage = api_response.usage.map(|u| TokenUsage {
            input_tokens: u.prompt_tokens,
            output_tokens: u.completion_tokens,
            cached_input_tokens: None,
        });
        let choice = api_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;

        // Missing tool-call ids get a synthetic UUID so callers can
        // always correlate results.
        let tool_calls = choice
            .message
            .tool_calls
            .unwrap_or_default()
            .into_iter()
            .map(|tool_call| ProviderToolCall {
                id: tool_call
                    .id
                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            })
            .collect();

        Ok(ProviderChatResponse {
            text: choice.message.content,
            tool_calls,
            usage,
            reasoning_content: None,
        })
    }

    /// Get a valid Copilot API key, refreshing or re-authenticating as needed.
    /// Uses a Mutex to ensure only one caller refreshes at a time.
    ///
    /// Returns `(token, api_endpoint)`. Keys within 120s of expiry are
    /// treated as stale to leave headroom for in-flight requests.
    async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
        // Holding the lock for the whole refresh serializes concurrent callers.
        let mut cached = self.refresh_lock.lock().await;

        // 1. In-memory cache.
        if let Some(cached_key) = cached.as_ref() {
            if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
                return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
            }
        }

        // 2. On-disk cache (survives restarts).
        if let Some(info) = self.load_api_key_from_disk().await {
            if chrono::Utc::now().timestamp() + 120 < info.expires_at {
                let endpoint = info
                    .endpoints
                    .as_ref()
                    .and_then(|e| e.api.clone())
                    .unwrap_or_else(|| DEFAULT_API.to_string());
                let token = info.token;

                *cached = Some(CachedApiKey {
                    token: token.clone(),
                    api_endpoint: endpoint.clone(),
                    expires_at: info.expires_at,
                });
                return Ok((token, endpoint));
            }
        }

        // 3. Full refresh: OAuth token -> Copilot key exchange.
        let access_token = self.get_github_access_token().await?;
        let api_key_info = self.exchange_for_api_key(&access_token).await?;
        self.save_api_key_to_disk(&api_key_info).await;

        let endpoint = api_key_info
            .endpoints
            .as_ref()
            .and_then(|e| e.api.clone())
            .unwrap_or_else(|| DEFAULT_API.to_string());

        *cached = Some(CachedApiKey {
            token: api_key_info.token.clone(),
            api_endpoint: endpoint.clone(),
            expires_at: api_key_info.expires_at,
        });

        Ok((api_key_info.token, endpoint))
    }

    /// Get a GitHub access token from config, cache, or device flow.
    async fn get_github_access_token(&self) -> anyhow::Result<String> {
        // Explicitly configured token wins.
        if let Some(token) = &self.github_token {
            return Ok(token.clone());
        }

        // Cached token from a previous device-flow login.
        let access_token_path = self.token_dir.join("access-token");
        if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
            let token = cached.trim();
            if !token.is_empty() {
                return Ok(token.to_string());
            }
        }

        // Interactive device flow; cache the result with 0600 perms.
        let token = self.device_code_login().await?;
        write_file_secure(&access_token_path, &token).await;
        Ok(token)
    }

    /// Run GitHub OAuth device code flow.
    ///
    /// Prints the verification URL/code to stderr and polls until the user
    /// authorizes, the code expires, or GitHub reports a hard error.
    async fn device_code_login(&self) -> anyhow::Result<String> {
        let response: DeviceCodeResponse = self
            .http_client()
            .post(GITHUB_DEVICE_CODE_URL)
            .header("Accept", "application/json")
            .json(&serde_json::json!({
                "client_id": GITHUB_CLIENT_ID,
                "scope": "read:user"
            }))
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?;

        // Respect the server-suggested interval, but never poll faster
        // than every 5 seconds.
        let mut poll_interval = Duration::from_secs(response.interval.max(5));
        let expires_in = response.expires_in.max(1);
        let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);

        eprintln!(
            "\nGitHub Copilot authentication is required.\n\
             Visit: {}\n\
             Code: {}\n\
             Waiting for authorization...\n",
            response.verification_uri, response.user_code
        );

        while tokio::time::Instant::now() < expires_at {
            tokio::time::sleep(poll_interval).await;

            let token_response: AccessTokenResponse = self
                .http_client()
                .post(GITHUB_ACCESS_TOKEN_URL)
                .header("Accept", "application/json")
                .json(&serde_json::json!({
                    "client_id": GITHUB_CLIENT_ID,
                    "device_code": response.device_code,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
                }))
                .send()
                .await?
                .json()
                .await?;

            if let Some(token) = token_response.access_token {
                eprintln!("Authentication succeeded.\n");
                return Ok(token);
            }

            match token_response.error.as_deref() {
                // GitHub asks us to back off: widen the polling interval.
                Some("slow_down") => {
                    poll_interval += Duration::from_secs(5);
                }
                // User has not approved yet; keep polling.
                Some("authorization_pending") | None => {}
                Some("expired_token") => {
                    anyhow::bail!("GitHub device authorization expired")
                }
                Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
            }
        }

        anyhow::bail!("Timed out waiting for GitHub authorization")
    }

    /// Exchange a GitHub access token for a Copilot API key.
    async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
        let mut request = self.http_client().get(GITHUB_API_KEY_URL);
        for (header, value) in &Self::COPILOT_HEADERS {
            request = request.header(*header, *value);
        }
        // Note: this endpoint uses the "token" scheme, not "Bearer".
        request = request.header("Authorization", format!("token {access_token}"));

        let response = request.send().await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            let sanitized = super::sanitize_api_error(&body);

            // Auth failure: drop the cached token so the next attempt
            // re-runs the device flow instead of looping on a dead token.
            if status.as_u16() == 401 || status.as_u16() == 403 {
                let access_token_path = self.token_dir.join("access-token");
                tokio::fs::remove_file(&access_token_path).await.ok();
            }

            anyhow::bail!(
                "Failed to get Copilot API key ({status}): {sanitized}. \
                Ensure your GitHub account has an active Copilot subscription."
            );
        }

        let info: ApiKeyInfo = response.json().await?;
        Ok(info)
    }

    // Best-effort read of the cached API key; None on any IO/parse failure.
    async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
        let path = self.token_dir.join("api-key.json");
        let data = tokio::fs::read_to_string(&path).await.ok()?;
        serde_json::from_str(&data).ok()
    }

    // Best-effort write of the API key cache (0600 perms; errors are logged
    // inside write_file_secure, not propagated).
    async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
        let path = self.token_dir.join("api-key.json");
        if let Ok(json) = serde_json::to_string_pretty(info) {
            write_file_secure(&path, &json).await;
        }
    }
}
|
||||
|
||||
/// Write a file with 0600 permissions (owner read/write only).
/// Uses `spawn_blocking` to avoid blocking the async runtime.
///
/// Best-effort: failures are logged via `warn!` and swallowed, so callers
/// (token caching) degrade gracefully instead of erroring out.
async fn write_file_secure(path: &Path, content: &str) {
    // Own the data so it can move into the blocking task.
    let path = path.to_path_buf();
    let content = content.to_string();

    let result = tokio::task::spawn_blocking(move || {
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

            // mode(0o600) only applies on create; the explicit
            // set_permissions below also tightens pre-existing files.
            let mut file = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&path)?;
            file.write_all(content.as_bytes())?;

            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
            Ok::<(), std::io::Error>(())
        }
        #[cfg(not(unix))]
        {
            // No Unix mode bits available: plain write.
            std::fs::write(&path, &content)?;
            Ok::<(), std::io::Error>(())
        }
    })
    .await;

    match result {
        Ok(Ok(())) => {}
        Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
        Err(err) => warn!("Failed to spawn blocking write: {err}"),
    }
}
|
||||
|
||||
#[async_trait]
impl Provider for CopilotProvider {
    /// One-shot chat: optional system message plus a single user message.
    ///
    /// The user message goes through `to_api_content`, which may split it into
    /// multipart content (e.g. image markers); the system message is plain text.
    /// Returns the assistant text, or an empty string if the response had none.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let mut messages = Vec::new();
        if let Some(system) = system_prompt {
            messages.push(ApiMessage {
                role: "system".to_string(),
                content: Some(ApiContent::Text(system.to_string())),
                tool_call_id: None,
                tool_calls: None,
            });
        }
        messages.push(ApiMessage {
            role: "user".to_string(),
            content: Self::to_api_content("user", message),
            tool_call_id: None,
            tool_calls: None,
        });

        // No tools for the plain-chat path.
        let response = self
            .send_chat_request(messages, None, model, temperature)
            .await?;
        Ok(response.text.unwrap_or_default())
    }

    /// Multi-turn chat without tools; converts the generic history into the
    /// Copilot wire format and returns only the assistant text.
    async fn chat_with_history(
        &self,
        messages: &[ChatMessage],
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let response = self
            .send_chat_request(Self::convert_messages(messages), None, model, temperature)
            .await?;
        Ok(response.text.unwrap_or_default())
    }

    /// Full chat entry point: forwards history AND the tool definitions,
    /// returning the complete `ProviderChatResponse` (text, tool calls, usage).
    async fn chat(
        &self,
        request: ProviderChatRequest<'_>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        self.send_chat_request(
            Self::convert_messages(request.messages),
            request.tools,
            model,
            temperature,
        )
        .await
    }

    /// Copilot supports OpenAI-style native tool calling.
    fn supports_native_tools(&self) -> bool {
        true
    }

    /// Pre-fetch (and cache) the Copilot API key so the first real chat
    /// request doesn't pay the auth round-trip.
    async fn warmup(&self) -> anyhow::Result<()> {
        let _ = self.get_api_key().await?;
        Ok(())
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // --- Constructor behavior -------------------------------------------

    #[test]
    fn new_without_token() {
        let provider = CopilotProvider::new(None);
        assert!(provider.github_token.is_none());
    }

    #[test]
    fn new_with_token() {
        let provider = CopilotProvider::new(Some("ghp_test"));
        assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
    }

    #[test]
    fn empty_token_treated_as_none() {
        // An empty string must normalize to "no token configured".
        let provider = CopilotProvider::new(Some(""));
        assert!(provider.github_token.is_none());
    }

    #[tokio::test]
    async fn cache_starts_empty() {
        // A fresh provider has no cached API key behind the refresh lock.
        let provider = CopilotProvider::new(None);
        let cached = provider.refresh_lock.lock().await;
        assert!(cached.is_none());
    }

    // --- Static configuration -------------------------------------------

    #[test]
    fn copilot_headers_include_required_fields() {
        // The Copilot API rejects requests missing these editor headers.
        let headers = CopilotProvider::COPILOT_HEADERS;
        assert!(headers
            .iter()
            .any(|(header, _)| *header == "Editor-Version"));
        assert!(headers
            .iter()
            .any(|(header, _)| *header == "Editor-Plugin-Version"));
        assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
    }

    #[test]
    fn default_interval_and_expiry() {
        // Serde defaults for the device-auth polling response.
        assert_eq!(default_interval(), 5);
        assert_eq!(default_expires_in(), 900);
    }

    #[test]
    fn supports_native_tools() {
        let provider = CopilotProvider::new(None);
        assert!(provider.supports_native_tools());
    }

    // --- Wire-format parsing --------------------------------------------

    #[test]
    fn api_response_parses_usage() {
        let json = r#"{
            "choices": [{"message": {"content": "Hello"}}],
            "usage": {"prompt_tokens": 200, "completion_tokens": 80}
        }"#;
        let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
        let usage = resp.usage.unwrap();
        assert_eq!(usage.prompt_tokens, Some(200));
        assert_eq!(usage.completion_tokens, Some(80));
    }

    #[test]
    fn api_response_parses_without_usage() {
        // `usage` is optional in the response schema.
        let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#;
        let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
        assert!(resp.usage.is_none());
    }

    // --- Content conversion ---------------------------------------------

    #[test]
    fn to_api_content_user_with_image_returns_parts() {
        // An [IMAGE:...] marker in a user message becomes multipart content.
        let content = "describe this [IMAGE:data:image/png;base64,abc123]";
        let result = CopilotProvider::to_api_content("user", content).unwrap();
        match result {
            ApiContent::Parts(parts) => {
                assert_eq!(parts.len(), 2);
                assert!(matches!(&parts[0], ContentPart::Text { text } if text == "describe this"));
                assert!(
                    matches!(&parts[1], ContentPart::ImageUrl { image_url } if image_url.url == "data:image/png;base64,abc123")
                );
            }
            ApiContent::Text(_) => {
                panic!("expected ApiContent::Parts for user message with image marker")
            }
        }
    }

    #[test]
    fn to_api_content_user_plain_returns_text() {
        let result = CopilotProvider::to_api_content("user", "hello world").unwrap();
        assert!(matches!(result, ApiContent::Text(ref s) if s == "hello world"));
    }

    #[test]
    fn to_api_content_non_user_returns_text() {
        // Only user messages get image-marker expansion.
        let result = CopilotProvider::to_api_content("system", "you are helpful").unwrap();
        assert!(matches!(result, ApiContent::Text(ref s) if s == "you are helpful"));

        let result = CopilotProvider::to_api_content("assistant", "sure").unwrap();
        assert!(matches!(result, ApiContent::Text(ref s) if s == "sure"));
    }
}
|
||||
2274
third_party/zeroclaw/src/providers/gemini.rs
vendored
Normal file
2274
third_party/zeroclaw/src/providers/gemini.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
326
third_party/zeroclaw/src/providers/gemini_cli.rs
vendored
Normal file
326
third_party/zeroclaw/src/providers/gemini_cli.rs
vendored
Normal file
@@ -0,0 +1,326 @@
|
||||
//! Gemini CLI subprocess provider.
|
||||
//!
|
||||
//! Integrates with the Gemini CLI, spawning the `gemini` binary
|
||||
//! as a subprocess for each inference request. This allows using Google's
|
||||
//! Gemini models via the CLI without an interactive UI session.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The `gemini` binary must be available in `PATH`, or its location must be
|
||||
//! set via the `GEMINI_CLI_PATH` environment variable.
|
||||
//!
|
||||
//! Gemini CLI is invoked as:
|
||||
//! ```text
|
||||
//! gemini --print -
|
||||
//! ```
|
||||
//! with prompt content written to stdin.
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **Conversation history**: Only the system prompt (if present) and the last
|
||||
//! user message are forwarded. Full multi-turn history is not preserved because
|
||||
//! the CLI accepts a single prompt per invocation.
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
//! Only default values are accepted; custom values return an explicit error.
|
||||
//!
|
||||
//! # Authentication
|
||||
//!
|
||||
//! Authentication is handled by the Gemini CLI itself (its own credential store).
|
||||
//! No explicit API key is required by this provider.
|
||||
//!
|
||||
//! # Environment variables
|
||||
//!
|
||||
//! - `GEMINI_CLI_PATH` — override the path to the `gemini` binary (default: `"gemini"`)
|
||||
|
||||
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Environment variable for overriding the path to the `gemini` binary.
pub const GEMINI_CLI_PATH_ENV: &str = "GEMINI_CLI_PATH";

/// Default `gemini` binary name (resolved via `PATH`).
const DEFAULT_GEMINI_CLI_BINARY: &str = "gemini";

/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// Gemini CLI requests are bounded to avoid hung subprocesses.
const GEMINI_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_GEMINI_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const GEMINI_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
/// Tolerance used when comparing an incoming f64 temperature against the
/// supported values (exact float equality would be fragile).
const TEMP_EPSILON: f64 = 1e-9;
|
||||
|
||||
/// Provider that invokes the Gemini CLI as a subprocess.
///
/// Each inference request spawns a fresh `gemini` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
///
/// The provider itself is stateless apart from the resolved binary path.
pub struct GeminiCliProvider {
    /// Path to the `gemini` binary (from `GEMINI_CLI_PATH` or just `"gemini"`).
    binary_path: PathBuf,
}
|
||||
|
||||
impl GeminiCliProvider {
    /// Create a new `GeminiCliProvider`.
    ///
    /// The binary path is resolved from `GEMINI_CLI_PATH` env var if set
    /// (blank/whitespace-only values are ignored), otherwise defaults to
    /// `"gemini"` (found via `PATH`).
    pub fn new() -> Self {
        let binary_path = std::env::var(GEMINI_CLI_PATH_ENV)
            .ok()
            .filter(|path| !path.trim().is_empty())
            .map(PathBuf::from)
            .unwrap_or_else(|| PathBuf::from(DEFAULT_GEMINI_CLI_BINARY));

        Self { binary_path }
    }

    /// Returns true if the model argument should be forwarded to the CLI.
    /// Empty/whitespace strings and the `"default"` marker are not forwarded.
    fn should_forward_model(model: &str) -> bool {
        let trimmed = model.trim();
        !trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
    }

    /// True when `temperature` is (within epsilon of) one of the accepted
    /// baseline values in `GEMINI_CLI_SUPPORTED_TEMPERATURES`.
    fn supports_temperature(temperature: f64) -> bool {
        GEMINI_CLI_SUPPORTED_TEMPERATURES
            .iter()
            .any(|v| (temperature - v).abs() < TEMP_EPSILON)
    }

    /// Reject NaN/infinite values and any temperature the CLI cannot honor,
    /// since the CLI exposes no sampling controls at all.
    fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
        if !temperature.is_finite() {
            anyhow::bail!("Gemini CLI provider received non-finite temperature value");
        }
        if !Self::supports_temperature(temperature) {
            anyhow::bail!(
                "temperature unsupported by Gemini CLI: {temperature}. \
                Supported values: 0.7 or 1.0"
            );
        }
        Ok(())
    }

    /// Lossy-decode stderr, trim it, and clip to `MAX_GEMINI_CLI_STDERR_CHARS`
    /// characters (with a `...` suffix) so error messages stay bounded.
    fn redact_stderr(stderr: &[u8]) -> String {
        let text = String::from_utf8_lossy(stderr);
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return String::new();
        }
        if trimmed.chars().count() <= MAX_GEMINI_CLI_STDERR_CHARS {
            return trimmed.to_string();
        }
        let clipped: String = trimmed.chars().take(MAX_GEMINI_CLI_STDERR_CHARS).collect();
        format!("{clipped}...")
    }

    /// Invoke the gemini binary with the given prompt and optional model.
    /// Returns the trimmed stdout output as the assistant response.
    ///
    /// The whole subprocess run is bounded by `GEMINI_CLI_REQUEST_TIMEOUT`;
    /// `kill_on_drop` ensures a timed-out child is reaped rather than leaked.
    async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
        let mut cmd = Command::new(&self.binary_path);
        cmd.arg("--print");

        if Self::should_forward_model(model) {
            cmd.arg("--model").arg(model);
        }

        // Read prompt from stdin to avoid exposing sensitive content in process args.
        cmd.arg("-");
        cmd.kill_on_drop(true);
        cmd.stdin(std::process::Stdio::piped());
        cmd.stdout(std::process::Stdio::piped());
        cmd.stderr(std::process::Stdio::piped());

        let mut child = cmd.spawn().map_err(|err| {
            anyhow::anyhow!(
                "Failed to spawn Gemini CLI binary at {}: {err}. \
                Ensure `gemini` is installed and in PATH, or set GEMINI_CLI_PATH.",
                self.binary_path.display()
            )
        })?;

        // Write the prompt and close stdin so the CLI sees EOF and starts.
        if let Some(mut stdin) = child.stdin.take() {
            stdin.write_all(message.as_bytes()).await.map_err(|err| {
                anyhow::anyhow!("Failed to write prompt to Gemini CLI stdin: {err}")
            })?;
            stdin.shutdown().await.map_err(|err| {
                anyhow::anyhow!("Failed to finalize Gemini CLI stdin stream: {err}")
            })?;
        }

        let output = timeout(GEMINI_CLI_REQUEST_TIMEOUT, child.wait_with_output())
            .await
            .map_err(|_| {
                anyhow::anyhow!(
                    "Gemini CLI request timed out after {:?} (binary: {})",
                    GEMINI_CLI_REQUEST_TIMEOUT,
                    self.binary_path.display()
                )
            })?
            .map_err(|err| anyhow::anyhow!("Gemini CLI process failed: {err}"))?;

        if !output.status.success() {
            // code() is None when the child was killed by a signal; report -1.
            let code = output.status.code().unwrap_or(-1);
            let stderr_excerpt = Self::redact_stderr(&output.stderr);
            let stderr_note = if stderr_excerpt.is_empty() {
                String::new()
            } else {
                format!(" Stderr: {stderr_excerpt}")
            };
            anyhow::bail!(
                "Gemini CLI exited with non-zero status {code}. \
                Check that Gemini CLI is authenticated and the CLI is supported.{stderr_note}"
            );
        }

        let text = String::from_utf8(output.stdout)
            .map_err(|err| anyhow::anyhow!("Gemini CLI produced non-UTF-8 output: {err}"))?;

        Ok(text.trim().to_string())
    }
}
|
||||
|
||||
impl Default for GeminiCliProvider {
    /// Equivalent to [`GeminiCliProvider::new`] (env-based binary resolution).
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[async_trait]
impl Provider for GeminiCliProvider {
    /// One-shot chat via the CLI.
    ///
    /// The CLI has no system-prompt flag, so a non-empty system prompt is
    /// prepended to the user message with a blank-line separator. Temperature
    /// is validated up front because the CLI exposes no sampling controls.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        Self::validate_temperature(temperature)?;

        let full_message = match system_prompt {
            Some(system) if !system.is_empty() => {
                format!("{system}\n\n{message}")
            }
            _ => message.to_string(),
        };

        self.invoke_cli(&full_message, model).await
    }

    /// Full chat entry point. Delegates to `chat_with_history` (trait-provided
    /// here), so only a reduced view of the history reaches the CLI — see the
    /// module docs' Limitations section.
    ///
    /// Tool calls are never produced (the CLI has no tool protocol) and usage
    /// is reported as a zeroed `TokenUsage` since the CLI emits no counts.
    async fn chat(
        &self,
        request: ChatRequest<'_>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ChatResponse> {
        let text = self
            .chat_with_history(request.messages, model, temperature)
            .await?;

        Ok(ChatResponse {
            text: Some(text),
            tool_calls: Vec::new(),
            usage: Some(TokenUsage::default()),
            reasoning_content: None,
        })
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serializes env-var mutation across tests: `std::env::set_var` affects
    /// process-global state, and tests run concurrently by default.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .expect("env lock poisoned")
    }

    #[test]
    fn new_uses_env_override() {
        let _guard = env_lock();
        let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
        std::env::set_var(GEMINI_CLI_PATH_ENV, "/usr/local/bin/gemini");
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/gemini"));
        // Restore prior state so later tests see a clean environment.
        match orig {
            Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
            None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
        }
    }

    #[test]
    fn new_defaults_to_gemini() {
        let _guard = env_lock();
        let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
        std::env::remove_var(GEMINI_CLI_PATH_ENV);
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("gemini"));
        if let Some(v) = orig {
            std::env::set_var(GEMINI_CLI_PATH_ENV, v);
        }
    }

    #[test]
    fn new_ignores_blank_env_override() {
        // Whitespace-only override must fall back to the default binary name.
        let _guard = env_lock();
        let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
        std::env::set_var(GEMINI_CLI_PATH_ENV, " ");
        let provider = GeminiCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("gemini"));
        match orig {
            Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
            None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
        }
    }

    #[test]
    fn should_forward_model_standard() {
        assert!(GeminiCliProvider::should_forward_model("gemini-2.5-pro"));
        assert!(GeminiCliProvider::should_forward_model("gemini-2.5-flash"));
    }

    #[test]
    fn should_not_forward_default_model() {
        // "default", empty, and whitespace-only all mean "let the CLI pick".
        assert!(!GeminiCliProvider::should_forward_model(
            DEFAULT_MODEL_MARKER
        ));
        assert!(!GeminiCliProvider::should_forward_model(""));
        assert!(!GeminiCliProvider::should_forward_model(" "));
    }

    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(GeminiCliProvider::validate_temperature(0.7).is_ok());
        assert!(GeminiCliProvider::validate_temperature(1.0).is_ok());
    }

    #[test]
    fn validate_temperature_rejects_custom_value() {
        let err = GeminiCliProvider::validate_temperature(0.2).unwrap_err();
        assert!(err
            .to_string()
            .contains("temperature unsupported by Gemini CLI"));
    }

    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        // Spawning a nonexistent binary must surface the descriptive error.
        let provider = GeminiCliProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/gemini"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn Gemini CLI binary"),
            "unexpected error message: {msg}"
        );
    }
}
|
||||
361
third_party/zeroclaw/src/providers/glm.rs
vendored
Normal file
361
third_party/zeroclaw/src/providers/glm.rs
vendored
Normal file
@@ -0,0 +1,361 @@
|
||||
//! Zhipu GLM provider with JWT authentication.
|
||||
//! The GLM API requires JWT tokens generated from the `id.secret` API key format
|
||||
//! with a custom `sign_type: "SIGN"` header, and uses `/v4/chat/completions`.
|
||||
|
||||
use crate::providers::traits::{ChatMessage, Provider};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use ring::hmac;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Mutex;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// Zhipu GLM provider state.
///
/// The GLM API key has the form `id.secret`; the two halves are stored
/// separately because the id goes into the JWT payload while the secret is
/// only used as the HMAC-SHA256 signing key.
pub struct GlmProvider {
    /// Left half of the `id.secret` API key (becomes the JWT `api_key` claim).
    api_key_id: String,
    /// Right half of the `id.secret` API key (HMAC-SHA256 signing secret).
    api_key_secret: String,
    /// API root, e.g. `https://api.z.ai/api/paas/v4`.
    base_url: String,
    /// Cached JWT token + expiry timestamp (ms)
    token_cache: Mutex<Option<(String, u64)>>,
}
|
||||
|
||||
/// Request body for GLM `POST /chat/completions`.
#[derive(Debug, Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
    temperature: f64,
}

/// One chat turn in the request payload (`role` is "system"/"user"/…).
#[derive(Debug, Serialize)]
struct Message {
    role: String,
    content: String,
}

/// Minimal response envelope — only `choices` is consumed here.
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

/// A single completion choice from the response.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}

/// The assistant message inside a choice.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    content: String,
}
|
||||
|
||||
/// Base64url encode without padding (per the JWT spec / RFC 4648 §5).
///
/// Encodes directly with the URL-safe alphabet (`-` and `_` in place of `+`
/// and `/`) and never emits `=` padding. This replaces the previous
/// encode-then-`replace('+')`-then-`replace('/')` approach, which allocated
/// two extra full-length strings per call; the output is byte-identical.
fn base64url_encode_bytes(data: &[u8]) -> String {
    // URL-safe alphabet: indices 62/63 are '-' and '_' instead of '+'/'/'.
    const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    // Every 3 input bytes become 4 output chars; round up for the tail.
    let mut result = String::with_capacity((data.len() + 2) / 3 * 4);

    for chunk in data.chunks(3) {
        // Pack up to 3 bytes into a 24-bit group; missing bytes are zero.
        let b0 = chunk[0] as u32;
        let b1 = chunk.get(1).copied().unwrap_or(0) as u32;
        let b2 = chunk.get(2).copied().unwrap_or(0) as u32;
        let triple = (b0 << 16) | (b1 << 8) | b2;

        // First two sextets always exist when there is at least one byte.
        result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char);
        result.push(CHARS[((triple >> 12) & 0x3F) as usize] as char);
        // Third/fourth sextets only when the chunk actually has 2/3 bytes —
        // this is what makes the output unpadded.
        if chunk.len() > 1 {
            result.push(CHARS[((triple >> 6) & 0x3F) as usize] as char);
        }
        if chunk.len() > 2 {
            result.push(CHARS[(triple & 0x3F) as usize] as char);
        }
    }

    result
}
||||
|
||||
/// Base64url-encode a UTF-8 string (convenience wrapper over the byte form).
fn base64url_encode_str(s: &str) -> String {
    base64url_encode_bytes(s.as_bytes())
}
|
||||
|
||||
impl GlmProvider {
    /// Build a provider from an optional `id.secret` API key.
    ///
    /// A missing key, or one without a `.` separator, leaves both halves
    /// empty; `generate_token` will then fail with a setup hint.
    pub fn new(api_key: Option<&str>) -> Self {
        let (id, secret) = api_key
            .and_then(|k| k.split_once('.'))
            .map(|(id, secret)| (id.to_string(), secret.to_string()))
            .unwrap_or_default();

        Self {
            api_key_id: id,
            api_key_secret: secret,
            base_url: "https://api.z.ai/api/paas/v4".to_string(),
            token_cache: Mutex::new(None),
        }
    }

    /// Mint (or reuse from cache) a short-lived HS256 JWT for the GLM API.
    ///
    /// The JWT is built by hand because GLM requires a non-standard
    /// `sign_type: "SIGN"` field in the JOSE header.
    ///
    /// # Errors
    /// Fails when the API key halves are empty (unset or invalid format).
    fn generate_token(&self) -> anyhow::Result<String> {
        if self.api_key_id.is_empty() || self.api_key_secret.is_empty() {
            anyhow::bail!(
                "GLM API key not set or invalid format. Expected 'id.secret'. \
                Run `zeroclaw onboard` or set GLM_API_KEY env var."
            );
        }

        let now_ms = SystemTime::now()
            .duration_since(UNIX_EPOCH)?
            .as_millis() as u64;

        // Check cache (valid for 3 minutes, token expires at 3.5 min)
        // — the 30s margin avoids handing out a token about to expire.
        if let Ok(cache) = self.token_cache.lock() {
            if let Some((ref token, expiry)) = *cache {
                if now_ms < expiry {
                    return Ok(token.clone());
                }
            }
        }

        let exp_ms = now_ms + 210_000; // 3.5 minutes

        // Build JWT manually to include custom sign_type header
        // Header: {"alg":"HS256","typ":"JWT","sign_type":"SIGN"}
        let header_json = r#"{"alg":"HS256","typ":"JWT","sign_type":"SIGN"}"#;
        let header_b64 = base64url_encode_str(header_json);

        // Payload: {"api_key":"...","exp":...,"timestamp":...}
        let payload_json = format!(
            r#"{{"api_key":"{}","exp":{},"timestamp":{}}}"#,
            self.api_key_id, exp_ms, now_ms
        );
        let payload_b64 = base64url_encode_str(&payload_json);

        // Sign: HMAC-SHA256(header.payload, secret)
        let signing_input = format!("{header_b64}.{payload_b64}");
        let key = hmac::Key::new(hmac::HMAC_SHA256, self.api_key_secret.as_bytes());
        let signature = hmac::sign(&key, signing_input.as_bytes());
        let sig_b64 = base64url_encode_bytes(signature.as_ref());

        let token = format!("{signing_input}.{sig_b64}");

        // Cache for 3 minutes
        if let Ok(mut cache) = self.token_cache.lock() {
            *cache = Some((token.clone(), now_ms + 180_000));
        }

        Ok(token)
    }

    /// Build an HTTP client honoring the runtime proxy config.
    /// The 120/10 arguments appear to be request/connect timeouts in
    /// seconds — TODO confirm against `build_runtime_proxy_client_with_timeouts`.
    fn http_client(&self) -> Client {
        crate::config::build_runtime_proxy_client_with_timeouts("provider.glm", 120, 10)
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for GlmProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let token = self.generate_token()?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: sys.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
};
|
||||
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
let response = self
|
||||
.http_client()
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("GLM API error: {error}");
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from GLM"))
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let token = self.generate_token()?;
|
||||
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: api_messages,
|
||||
temperature,
|
||||
};
|
||||
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("GLM API error: {error}");
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from GLM"))
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
if self.api_key_id.is_empty() || self.api_key_secret.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Generate and cache a JWT token, establishing TLS to the GLM API.
|
||||
let token = self.generate_token()?;
|
||||
let url = format!("{}/chat/completions", self.base_url);
|
||||
// GET will likely return 405 but establishes the TLS + HTTP/2 connection pool.
|
||||
let _ = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // --- API key parsing -------------------------------------------------

    #[test]
    fn parses_api_key() {
        let p = GlmProvider::new(Some("abc123.secretXYZ"));
        assert_eq!(p.api_key_id, "abc123");
        assert_eq!(p.api_key_secret, "secretXYZ");
    }

    #[test]
    fn handles_no_key() {
        let p = GlmProvider::new(None);
        assert!(p.api_key_id.is_empty());
        assert!(p.api_key_secret.is_empty());
    }

    #[test]
    fn handles_invalid_key_format() {
        // Keys without the `.` separator are treated as unset.
        let p = GlmProvider::new(Some("no-dot-here"));
        assert!(p.api_key_id.is_empty());
        assert!(p.api_key_secret.is_empty());
    }

    // --- JWT generation ---------------------------------------------------

    #[test]
    fn generates_jwt_token() {
        let p = GlmProvider::new(Some("testid.testsecret"));
        let token = p.generate_token().unwrap();
        assert!(!token.is_empty());
        // JWT has 3 dot-separated parts
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3, "JWT should have 3 parts: {token}");
    }

    #[test]
    fn caches_token() {
        // Two back-to-back calls must hit the 3-minute cache.
        let p = GlmProvider::new(Some("testid.testsecret"));
        let token1 = p.generate_token().unwrap();
        let token2 = p.generate_token().unwrap();
        assert_eq!(token1, token2, "Cached token should be reused");
    }

    #[test]
    fn fails_without_key() {
        let p = GlmProvider::new(None);
        let result = p.generate_token();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("API key not set"));
    }

    // --- Chat entry points fail fast without credentials ------------------

    #[tokio::test]
    async fn chat_fails_without_key() {
        let p = GlmProvider::new(None);
        let result = p
            .chat_with_system(None, "hello", "glm-4.7", 0.7)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn chat_with_history_fails_without_key() {
        let p = GlmProvider::new(None);
        let messages = vec![
            ChatMessage::system("You are helpful."),
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
            ChatMessage::user("What did I say?"),
        ];
        let result = p.chat_with_history(&messages, "glm-4.7", 0.7).await;
        assert!(result.is_err());
    }

    // --- Encoding ---------------------------------------------------------

    #[test]
    fn base64url_no_padding() {
        // Output must use the URL-safe alphabet with no `=` padding.
        let encoded = base64url_encode_bytes(b"hello");
        assert!(!encoded.contains('='));
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
    }

    #[tokio::test]
    async fn warmup_without_key_is_noop() {
        let provider = GlmProvider::new(None);
        let result = provider.warmup().await;
        assert!(result.is_ok());
    }
}
|
||||
326
third_party/zeroclaw/src/providers/kilocli.rs
vendored
Normal file
326
third_party/zeroclaw/src/providers/kilocli.rs
vendored
Normal file
@@ -0,0 +1,326 @@
|
||||
//! KiloCLI subprocess provider.
|
||||
//!
|
||||
//! Integrates with the KiloCLI tool, spawning the `kilo` binary
|
||||
//! as a subprocess for each inference request. This allows using KiloCLI's AI
|
||||
//! models without an interactive UI session.
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The `kilo` binary must be available in `PATH`, or its location must be
|
||||
//! set via the `KILO_CLI_PATH` environment variable.
|
||||
//!
|
||||
//! KiloCLI is invoked as:
|
||||
//! ```text
|
||||
//! kilo --print -
|
||||
//! ```
|
||||
//! with prompt content written to stdin.
|
||||
//!
|
||||
//! # Limitations
|
||||
//!
|
||||
//! - **Conversation history**: Only the system prompt (if present) and the last
|
||||
//! user message are forwarded. Full multi-turn history is not preserved because
|
||||
//! the CLI accepts a single prompt per invocation.
|
||||
//! - **System prompt**: The system prompt is prepended to the user message with a
|
||||
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
|
||||
//! - **Temperature**: The CLI does not expose a temperature parameter.
|
||||
//! Only default values are accepted; custom values return an explicit error.
|
||||
//!
|
||||
//! # Authentication
|
||||
//!
|
||||
//! Authentication is handled by KiloCLI itself (its own credential store).
|
||||
//! No explicit API key is required by this provider.
|
||||
//!
|
||||
//! # Environment variables
|
||||
//!
|
||||
//! - `KILO_CLI_PATH` — override the path to the `kilo` binary (default: `"kilo"`)
|
||||
|
||||
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
/// Environment variable for overriding the path to the `kilo` binary.
pub const KILO_CLI_PATH_ENV: &str = "KILO_CLI_PATH";

/// Default `kilo` binary name (resolved via `PATH`).
const DEFAULT_KILO_CLI_BINARY: &str = "kilo";

/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// KiloCLI requests are bounded to avoid hung subprocesses.
const KILO_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_KILO_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const KILO_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
/// Tolerance used when comparing an incoming f64 temperature against the
/// supported values (exact float equality would be fragile).
const TEMP_EPSILON: f64 = 1e-9;
|
||||
|
||||
/// Provider that invokes the KiloCLI as a subprocess.
///
/// Each inference request spawns a fresh `kilo` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct KiloCliProvider {
    /// Path to the `kilo` binary (either from `KILO_CLI_PATH` or the bare
    /// name `"kilo"`, resolved via `PATH` at spawn time).
    binary_path: PathBuf,
}
|
||||
|
||||
impl KiloCliProvider {
    /// Create a new `KiloCliProvider`.
    ///
    /// The binary path is resolved from `KILO_CLI_PATH` env var if set
    /// (blank/whitespace-only values are ignored), otherwise defaults to
    /// `"kilo"` (found via `PATH`).
    pub fn new() -> Self {
        let binary_path = std::env::var(KILO_CLI_PATH_ENV)
            .ok()
            .filter(|path| !path.trim().is_empty())
            .map(PathBuf::from)
            .unwrap_or_else(|| PathBuf::from(DEFAULT_KILO_CLI_BINARY));

        Self { binary_path }
    }

    /// Returns true if the model argument should be forwarded to the CLI.
    ///
    /// Empty, whitespace-only, or the literal `"default"` marker all mean
    /// "let the CLI pick its own default model" and are not forwarded.
    fn should_forward_model(model: &str) -> bool {
        let trimmed = model.trim();
        !trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
    }

    /// True when `temperature` matches (within `TEMP_EPSILON`) one of the
    /// values the CLI is known to accept.
    fn supports_temperature(temperature: f64) -> bool {
        KILO_CLI_SUPPORTED_TEMPERATURES
            .iter()
            .any(|v| (temperature - v).abs() < TEMP_EPSILON)
    }

    /// Validate a requested temperature against what the CLI supports.
    ///
    /// # Errors
    /// Fails for non-finite values (NaN/inf) and for any finite value other
    /// than the supported defaults, since the CLI exposes no sampling flags.
    fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
        if !temperature.is_finite() {
            anyhow::bail!("KiloCLI provider received non-finite temperature value");
        }
        if !Self::supports_temperature(temperature) {
            anyhow::bail!(
                "temperature unsupported by KiloCLI: {temperature}. \
                Supported values: 0.7 or 1.0"
            );
        }
        Ok(())
    }

    /// Produce a bounded, trimmed excerpt of subprocess stderr for inclusion
    /// in error messages (clipped to `MAX_KILO_CLI_STDERR_CHARS` characters,
    /// with a `...` suffix when truncated).
    fn redact_stderr(stderr: &[u8]) -> String {
        let text = String::from_utf8_lossy(stderr);
        let trimmed = text.trim();
        if trimmed.is_empty() {
            return String::new();
        }
        if trimmed.chars().count() <= MAX_KILO_CLI_STDERR_CHARS {
            return trimmed.to_string();
        }
        let clipped: String = trimmed.chars().take(MAX_KILO_CLI_STDERR_CHARS).collect();
        format!("{clipped}...")
    }

    /// Invoke the kilo binary with the given prompt and optional model.
    /// Returns the trimmed stdout output as the assistant response.
    ///
    /// # Errors
    /// Fails if the binary cannot be spawned, stdin cannot be written, the
    /// process exceeds `KILO_CLI_REQUEST_TIMEOUT`, exits with a non-zero
    /// status, or emits non-UTF-8 output.
    async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
        let mut cmd = Command::new(&self.binary_path);
        cmd.arg("--print");

        if Self::should_forward_model(model) {
            cmd.arg("--model").arg(model);
        }

        // Read prompt from stdin to avoid exposing sensitive content in process args.
        cmd.arg("-");
        // Reap the subprocess if this future is dropped (e.g. on timeout below).
        cmd.kill_on_drop(true);
        cmd.stdin(std::process::Stdio::piped());
        cmd.stdout(std::process::Stdio::piped());
        cmd.stderr(std::process::Stdio::piped());

        let mut child = cmd.spawn().map_err(|err| {
            anyhow::anyhow!(
                "Failed to spawn KiloCLI binary at {}: {err}. \
                Ensure `kilo` is installed and in PATH, or set KILO_CLI_PATH.",
                self.binary_path.display()
            )
        })?;

        if let Some(mut stdin) = child.stdin.take() {
            stdin
                .write_all(message.as_bytes())
                .await
                .map_err(|err| anyhow::anyhow!("Failed to write prompt to KiloCLI stdin: {err}"))?;
            // Close stdin so the CLI sees EOF and begins processing the prompt.
            stdin
                .shutdown()
                .await
                .map_err(|err| anyhow::anyhow!("Failed to finalize KiloCLI stdin stream: {err}"))?;
        }

        // Bound the total wall-clock time; kill_on_drop above ensures the child
        // does not outlive a timed-out request.
        let output = timeout(KILO_CLI_REQUEST_TIMEOUT, child.wait_with_output())
            .await
            .map_err(|_| {
                anyhow::anyhow!(
                    "KiloCLI request timed out after {:?} (binary: {})",
                    KILO_CLI_REQUEST_TIMEOUT,
                    self.binary_path.display()
                )
            })?
            .map_err(|err| anyhow::anyhow!("KiloCLI process failed: {err}"))?;

        if !output.status.success() {
            // Status code is None when the process was killed by a signal; use -1.
            let code = output.status.code().unwrap_or(-1);
            let stderr_excerpt = Self::redact_stderr(&output.stderr);
            let stderr_note = if stderr_excerpt.is_empty() {
                String::new()
            } else {
                format!(" Stderr: {stderr_excerpt}")
            };
            anyhow::bail!(
                "KiloCLI exited with non-zero status {code}. \
                Check that KiloCLI is authenticated and the CLI is supported.{stderr_note}"
            );
        }

        let text = String::from_utf8(output.stdout)
            .map_err(|err| anyhow::anyhow!("KiloCLI produced non-UTF-8 output: {err}"))?;

        Ok(text.trim().to_string())
    }
}
|
||||
|
||||
impl Default for KiloCliProvider {
    /// Equivalent to [`KiloCliProvider::new`]: resolves the binary path from
    /// the environment, falling back to `"kilo"` on `PATH`.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for KiloCliProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Self::validate_temperature(temperature)?;
|
||||
|
||||
let full_message = match system_prompt {
|
||||
Some(system) if !system.is_empty() => {
|
||||
format!("{system}\n\n{message}")
|
||||
}
|
||||
_ => message.to_string(),
|
||||
};
|
||||
|
||||
self.invoke_cli(&full_message, model).await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
||||
Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
usage: Some(TokenUsage::default()),
|
||||
reasoning_content: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serialize tests that mutate `KILO_CLI_PATH`: env vars are process-wide
    /// and the test harness runs tests on multiple threads.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .expect("env lock poisoned")
    }

    #[test]
    fn new_uses_env_override() {
        let _guard = env_lock();
        // Save and restore the original value so sibling tests are unaffected.
        let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
        std::env::set_var(KILO_CLI_PATH_ENV, "/usr/local/bin/kilo");
        let provider = KiloCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/kilo"));
        match orig {
            Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
            None => std::env::remove_var(KILO_CLI_PATH_ENV),
        }
    }

    #[test]
    fn new_defaults_to_kilo() {
        let _guard = env_lock();
        let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
        std::env::remove_var(KILO_CLI_PATH_ENV);
        let provider = KiloCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("kilo"));
        if let Some(v) = orig {
            std::env::set_var(KILO_CLI_PATH_ENV, v);
        }
    }

    #[test]
    fn new_ignores_blank_env_override() {
        let _guard = env_lock();
        let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
        // Whitespace-only overrides must fall back to the default binary name.
        std::env::set_var(KILO_CLI_PATH_ENV, " ");
        let provider = KiloCliProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("kilo"));
        match orig {
            Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
            None => std::env::remove_var(KILO_CLI_PATH_ENV),
        }
    }

    #[test]
    fn should_forward_model_standard() {
        assert!(KiloCliProvider::should_forward_model("some-model"));
        assert!(KiloCliProvider::should_forward_model("gpt-4o"));
    }

    #[test]
    fn should_not_forward_default_model() {
        // The "default" marker, empty, and whitespace-only names are all
        // treated as "use the CLI's own default".
        assert!(!KiloCliProvider::should_forward_model(DEFAULT_MODEL_MARKER));
        assert!(!KiloCliProvider::should_forward_model(""));
        assert!(!KiloCliProvider::should_forward_model(" "));
    }

    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(KiloCliProvider::validate_temperature(0.7).is_ok());
        assert!(KiloCliProvider::validate_temperature(1.0).is_ok());
    }

    #[test]
    fn validate_temperature_rejects_custom_value() {
        let err = KiloCliProvider::validate_temperature(0.2).unwrap_err();
        assert!(err
            .to_string()
            .contains("temperature unsupported by KiloCLI"));
    }

    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        // A nonexistent binary path must fail at spawn with the guidance message.
        let provider = KiloCliProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/kilo"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn KiloCLI binary"),
            "unexpected error message: {msg}"
        );
    }
}
|
||||
3625
third_party/zeroclaw/src/providers/mod.rs
vendored
Normal file
3625
third_party/zeroclaw/src/providers/mod.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1381
third_party/zeroclaw/src/providers/ollama.rs
vendored
Normal file
1381
third_party/zeroclaw/src/providers/ollama.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1017
third_party/zeroclaw/src/providers/openai.rs
vendored
Normal file
1017
third_party/zeroclaw/src/providers/openai.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1177
third_party/zeroclaw/src/providers/openai_codex.rs
vendored
Normal file
1177
third_party/zeroclaw/src/providers/openai_codex.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1237
third_party/zeroclaw/src/providers/openrouter.rs
vendored
Normal file
1237
third_party/zeroclaw/src/providers/openrouter.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
2982
third_party/zeroclaw/src/providers/reliable.rs
vendored
Normal file
2982
third_party/zeroclaw/src/providers/reliable.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
1179
third_party/zeroclaw/src/providers/router.rs
vendored
Normal file
1179
third_party/zeroclaw/src/providers/router.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
391
third_party/zeroclaw/src/providers/telnyx.rs
vendored
Normal file
391
third_party/zeroclaw/src/providers/telnyx.rs
vendored
Normal file
@@ -0,0 +1,391 @@
|
||||
//! Telnyx AI inference provider.
|
||||
//!
|
||||
//! Telnyx provides AI inference through an OpenAI-compatible API at
|
||||
//! https://api.telnyx.com/v2/ai with access to 53+ models including
|
||||
//! GPT-4o, Claude, Llama, Mistral, and more.
|
||||
//!
|
||||
//! # Configuration
|
||||
//!
|
||||
//! Set the `TELNYX_API_KEY` environment variable or configure in `config.toml`:
|
||||
//!
|
||||
//! ```toml
|
||||
//! default_provider = "telnyx"
|
||||
//! default_model = "openai/gpt-4o"
|
||||
//! ```
|
||||
|
||||
use crate::providers::traits::{ChatMessage, Provider};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
|
||||
/// Telnyx AI inference provider.
///
/// Uses the OpenAI-compatible chat completions API at `/v2/ai/chat/completions`.
/// Supports 53+ models including OpenAI, Anthropic (via API), Meta Llama,
/// Mistral, and more.
///
/// # Example
///
/// ```rust,ignore
/// use zeroclaw::providers::telnyx::TelnyxProvider;
/// use zeroclaw::providers::Provider;
///
/// let provider = TelnyxProvider::new(Some("your-api-key"));
/// let response = provider.chat("Hello!", "openai/gpt-4o", 0.7).await?;
/// ```
pub struct TelnyxProvider {
    /// Telnyx API key; `None` when neither the constructor argument nor any
    /// known environment variable supplied one (requests then fail early).
    api_key: Option<String>,
    /// HTTP client for API requests (reused across calls for connection pooling)
    client: Client,
}
|
||||
|
||||
impl TelnyxProvider {
|
||||
/// Telnyx AI API base URL
|
||||
const BASE_URL: &'static str = "https://api.telnyx.com/v2/ai";
|
||||
|
||||
/// Create a new Telnyx AI provider.
|
||||
///
|
||||
/// The API key can be provided directly or will be resolved from:
|
||||
/// 1. `TELNYX_API_KEY` environment variable
|
||||
/// 2. `ZEROCLAW_API_KEY` environment variable (fallback)
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
let resolved_key = resolve_telnyx_api_key(api_key);
|
||||
Self {
|
||||
api_key: resolved_key,
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a provider with a custom base URL (for testing or proxies).
|
||||
pub fn with_base_url(api_key: Option<&str>, _base_url: &str) -> Self {
|
||||
// Note: custom base URL support for testing
|
||||
Self::new(api_key)
|
||||
}
|
||||
|
||||
/// List available models from Telnyx AI.
|
||||
///
|
||||
/// Returns a list of model IDs that can be used with the chat API.
|
||||
pub async fn list_models(&self) -> anyhow::Result<Vec<String>> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("Telnyx API key not set. Set TELNYX_API_KEY environment variable.")
|
||||
})?;
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(format!("{}/models", Self::BASE_URL))
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("Failed to list Telnyx models: {}", error);
|
||||
}
|
||||
|
||||
let models_response: ModelsResponse = response.json().await?;
|
||||
Ok(models_response.data.into_iter().map(|m| m.id).collect())
|
||||
}
|
||||
|
||||
/// Build the chat completions URL
|
||||
fn chat_url(&self) -> String {
|
||||
format!("{}/chat/completions", Self::BASE_URL)
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the Telnyx API key.
///
/// Resolution order: the explicit `api_key` parameter (trimmed), then the
/// `TELNYX_API_KEY`, `ZEROCLAW_API_KEY`, and `API_KEY` environment variables,
/// in that order. Blank or whitespace-only candidates are skipped. Returns
/// `None` when no non-empty key is found anywhere.
fn resolve_telnyx_api_key(api_key: Option<&str>) -> Option<String> {
    // An explicit, non-blank parameter always wins over the environment.
    if let Some(explicit) = api_key
        .map(str::trim)
        .filter(|candidate| !candidate.is_empty())
    {
        return Some(explicit.to_string());
    }

    // Otherwise scan env vars from most to least specific.
    ["TELNYX_API_KEY", "ZEROCLAW_API_KEY", "API_KEY"]
        .iter()
        .find_map(|var| {
            std::env::var(var)
                .ok()
                .map(|value| value.trim().to_string())
                .filter(|value| !value.is_empty())
        })
}
|
||||
|
||||
/// Response from the /models endpoint
#[derive(Debug, Deserialize)]
struct ModelsResponse {
    // Model entries returned by the API.
    data: Vec<ModelInfo>,
}

/// A single model entry from the /models listing.
#[derive(Debug, Deserialize)]
struct ModelInfo {
    // Model identifier, e.g. "openai/gpt-4o".
    id: String,
}

/// Request body for chat completions
#[derive(Debug, serde::Serialize)]
struct ChatRequest {
    // Target model identifier forwarded verbatim.
    model: String,
    // Ordered conversation messages.
    messages: Vec<Message>,
    // Sampling temperature forwarded verbatim to the API.
    temperature: f64,
}

/// A single chat message in OpenAI-compatible `{role, content}` form.
#[derive(Debug, serde::Serialize)]
struct Message {
    role: String,
    content: String,
}

/// Response from chat completions API
#[derive(Debug, Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

/// One completion choice; callers use only the first.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}

/// Assistant message payload within a choice.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    content: String,
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for TelnyxProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Telnyx API key not set. Set TELNYX_API_KEY environment variable or run `zeroclaw onboard`."
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: sys.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(self.chat_url())
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error = response.text().await?;
|
||||
let sanitized = super::sanitize_api_error(&error);
|
||||
anyhow::bail!("Telnyx API error ({}): {}", status, sanitized);
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Telnyx"))
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Telnyx API key not set. Set TELNYX_API_KEY environment variable or run `zeroclaw onboard`."
|
||||
)
|
||||
})?;
|
||||
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: api_messages,
|
||||
temperature,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(self.chat_url())
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error = response.text().await?;
|
||||
let sanitized = super::sanitize_api_error(&error);
|
||||
anyhow::bail!("Telnyx API error ({}): {}", status, sanitized);
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Telnyx"))
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
// Pre-warm the connection pool
|
||||
let _ = self
|
||||
.client
|
||||
.get(format!("{}/models", Self::BASE_URL))
|
||||
.send()
|
||||
.await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Popular Telnyx AI models for easy reference.
///
/// These are convenience constants only; any model ID returned by
/// `TelnyxProvider::list_models` is accepted by the chat API.
pub mod models {
    /// OpenAI GPT-4o (recommended for most tasks)
    pub const GPT_4O: &str = "openai/gpt-4o";
    /// OpenAI GPT-4o Mini (fast and cost-effective)
    pub const GPT_4O_MINI: &str = "openai/gpt-4o-mini";
    /// OpenAI GPT-4 Turbo
    pub const GPT_4_TURBO: &str = "openai/gpt-4-turbo";
    /// Anthropic Claude 3.5 Sonnet (via Telnyx proxy)
    pub const CLAUDE_3_5_SONNET: &str = "anthropic/claude-3.5-sonnet";
    /// Meta Llama 3.1 70B Instruct
    pub const LLAMA_3_1_70B: &str = "meta-llama/llama-3.1-70b-instruct";
    /// Meta Llama 3.1 8B Instruct (fast)
    pub const LLAMA_3_1_8B: &str = "meta-llama/llama-3.1-8b-instruct";
    /// Mistral Large
    pub const MISTRAL_LARGE: &str = "mistralai/mistral-large";
    /// Mistral Small (fast)
    pub const MISTRAL_SMALL: &str = "mistralai/mistral-small";
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn creates_provider_with_key() {
        let provider = TelnyxProvider::new(Some("test-key"));
        assert!(provider.api_key.is_some());
    }

    #[test]
    fn creates_provider_without_key() {
        // Construction must not fail even with no key available.
        let _provider = TelnyxProvider::new(None);
        // Will be None if env vars not set
    }

    #[test]
    fn model_constants_are_valid() {
        // Each constant carries its vendor prefix as the API expects.
        assert!(models::GPT_4O.starts_with("openai/"));
        assert!(models::CLAUDE_3_5_SONNET.starts_with("anthropic/"));
        assert!(models::LLAMA_3_1_70B.starts_with("meta-llama/"));
        assert!(models::MISTRAL_LARGE.starts_with("mistralai/"));
    }

    #[test]
    fn resolve_key_from_parameter() {
        let key = resolve_telnyx_api_key(Some("direct-key"));
        assert_eq!(key, Some("direct-key".to_string()));
    }

    #[test]
    fn resolve_key_trims_whitespace() {
        let key = resolve_telnyx_api_key(Some(" spaced-key "));
        assert_eq!(key, Some("spaced-key".to_string()));
    }

    #[test]
    fn models_response_deserializes() {
        // Mirrors the shape of the real /models payload.
        let json = r#"{
            "data": [
                {"id": "openai/gpt-4o"},
                {"id": "anthropic/claude-3.5-sonnet"}
            ]
        }"#;

        let response: ModelsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.data.len(), 2);
        assert_eq!(response.data[0].id, "openai/gpt-4o");
    }

    #[test]
    fn chat_request_serializes() {
        let req = ChatRequest {
            model: "openai/gpt-4o".to_string(),
            messages: vec![
                Message {
                    role: "system".to_string(),
                    content: "You are helpful.".to_string(),
                },
                Message {
                    role: "user".to_string(),
                    content: "Hello".to_string(),
                },
            ],
            temperature: 0.7,
        };

        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("openai/gpt-4o"));
        assert!(json.contains("system"));
        assert!(json.contains("user"));
    }

    #[test]
    fn chat_response_deserializes() {
        let json = r#"{"choices":[{"message":{"content":"Hello from Telnyx!"}}]}"#;
        let resp: ChatResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.choices[0].message.content, "Hello from Telnyx!");
    }
}
|
||||
1092
third_party/zeroclaw/src/providers/traits.rs
vendored
Normal file
1092
third_party/zeroclaw/src/providers/traits.rs
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user