feat: refactor sgclaw around zeroclaw compat runtime

zyl
2026-03-26 16:23:31 +08:00
parent bca5b75801
commit ff0771a83f
1059 changed files with 409460 additions and 23 deletions

File diff suppressed because it is too large

@@ -0,0 +1,759 @@
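//! Azure OpenAI chat-completions provider.
//!
//! Requests go to a deployment-scoped endpoint of the form
//! `https://{resource}.openai.azure.com/openai/deployments/{deployment}`,
//! authenticated with the `api-key` header (set `AZURE_OPENAI_API_KEY` or
//! edit config.toml). The `api-version` query parameter defaults to
//! `2024-08-01-preview`. Unlike the plain OpenAI API, the request body
//! carries no `model` field: the deployment in the URL selects the model.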
use crate::providers::traits::{
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload,
};
use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
const DEFAULT_API_VERSION: &str = "2024-08-01-preview";
pub struct AzureOpenAiProvider {
credential: Option<String>,
resource_name: String,
deployment_name: String,
api_version: String,
base_url: String,
}
#[derive(Debug, Serialize)]
struct ChatRequest {
messages: Vec<Message>,
temperature: f64,
}
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
}
impl ResponseMessage {
fn effective_content(&self) -> String {
match &self.content {
Some(c) if !c.is_empty() => c.clone(),
_ => self.reasoning_content.clone().unwrap_or_default(),
}
}
}
#[derive(Debug, Serialize)]
struct NativeChatRequest {
messages: Vec<NativeMessage>,
temperature: f64,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<NativeToolSpec>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<String>,
}
#[derive(Debug, Serialize)]
struct NativeMessage {
role: String,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<NativeToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolSpec {
#[serde(rename = "type")]
kind: String,
function: NativeToolFunctionSpec,
}
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolFunctionSpec {
name: String,
description: String,
parameters: serde_json::Value,
}
fn parse_native_tool_spec(value: serde_json::Value) -> anyhow::Result<NativeToolSpec> {
let spec: NativeToolSpec = serde_json::from_value(value)
.map_err(|e| anyhow::anyhow!("Invalid Azure OpenAI tool specification: {e}"))?;
if spec.kind != "function" {
anyhow::bail!(
"Invalid Azure OpenAI tool specification: unsupported tool type '{}', expected 'function'",
spec.kind
);
}
Ok(spec)
}
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
kind: Option<String>,
function: NativeFunctionCall,
}
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
name: String,
arguments: String,
}
#[derive(Debug, Deserialize)]
struct NativeChatResponse {
choices: Vec<NativeChoice>,
#[serde(default)]
usage: Option<UsageInfo>,
}
#[derive(Debug, Deserialize)]
struct UsageInfo {
#[serde(default)]
prompt_tokens: Option<u64>,
#[serde(default)]
completion_tokens: Option<u64>,
}
#[derive(Debug, Deserialize)]
struct NativeChoice {
message: NativeResponseMessage,
}
#[derive(Debug, Deserialize)]
struct NativeResponseMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<NativeToolCall>>,
}
impl NativeResponseMessage {
fn effective_content(&self) -> Option<String> {
match &self.content {
Some(c) if !c.is_empty() => Some(c.clone()),
_ => self.reasoning_content.clone(),
}
}
}
impl AzureOpenAiProvider {
pub fn new(
credential: Option<&str>,
resource_name: &str,
deployment_name: &str,
api_version: Option<&str>,
) -> Self {
let version = api_version.unwrap_or(DEFAULT_API_VERSION);
let base_url = format!(
"https://{}.openai.azure.com/openai/deployments/{}",
resource_name, deployment_name
);
Self {
credential: credential.map(ToString::to_string),
resource_name: resource_name.to_string(),
deployment_name: deployment_name.to_string(),
api_version: version.to_string(),
base_url,
}
}
fn chat_completions_url(&self) -> String {
format!(
"{}/chat/completions?api-version={}",
self.base_url, self.api_version
)
}
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
tools.map(|items| {
items
.iter()
.map(|tool| NativeToolSpec {
kind: "function".to_string(),
function: NativeToolFunctionSpec {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.parameters.clone(),
},
})
.collect()
})
}
fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
messages
.iter()
.map(|m| {
if m.role == "assistant" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
if let Some(tool_calls_value) = value.get("tool_calls") {
if let Ok(parsed_calls) =
serde_json::from_value::<Vec<ProviderToolCall>>(
tool_calls_value.clone(),
)
{
let tool_calls = parsed_calls
.into_iter()
.map(|tc| NativeToolCall {
id: Some(tc.id),
kind: Some("function".to_string()),
function: NativeFunctionCall {
name: tc.name,
arguments: tc.arguments,
},
})
.collect::<Vec<_>>();
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
let reasoning_content = value
.get("reasoning_content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
return NativeMessage {
role: "assistant".to_string(),
content,
tool_call_id: None,
tool_calls: Some(tool_calls),
reasoning_content,
};
}
}
}
}
if m.role == "tool" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
let tool_call_id = value
.get("tool_call_id")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
return NativeMessage {
role: "tool".to_string(),
content,
tool_call_id,
tool_calls: None,
reasoning_content: None,
};
}
}
NativeMessage {
role: m.role.clone(),
content: Some(m.content.clone()),
tool_call_id: None,
tool_calls: None,
reasoning_content: None,
}
})
.collect()
}
fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
let text = message.effective_content();
let reasoning_content = message.reasoning_content.clone();
let tool_calls = message
.tool_calls
.unwrap_or_default()
.into_iter()
.map(|tc| ProviderToolCall {
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
name: tc.function.name,
arguments: tc.function.arguments,
})
.collect::<Vec<_>>();
ProviderChatResponse {
text,
tool_calls,
usage: None,
reasoning_content,
}
}
fn http_client(&self) -> Client {
crate::config::build_runtime_proxy_client_with_timeouts("provider.azure_openai", 120, 10)
}
}
#[async_trait]
impl Provider for AzureOpenAiProvider {
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities {
native_tool_calling: true,
vision: true,
prompt_caching: false,
}
}
fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
ToolsPayload::OpenAI {
tools: tools
.iter()
.map(|tool| {
serde_json::json!({
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters,
}
})
})
.collect(),
}
}
fn supports_native_tools(&self) -> bool {
true
}
fn supports_vision(&self) -> bool {
true
}
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
_model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
)
})?;
let mut messages = Vec::new();
if let Some(sys) = system_prompt {
messages.push(Message {
role: "system".to_string(),
content: sys.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: message.to_string(),
});
let request = ChatRequest {
messages,
temperature,
};
let response = self
.http_client()
.post(self.chat_completions_url())
.header("api-key", credential.as_str())
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(super::api_error("Azure OpenAI", response).await);
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.effective_content())
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))
}
async fn chat(
&self,
request: ProviderChatRequest<'_>,
_model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
)
})?;
let tools = Self::convert_tools(request.tools);
let native_request = NativeChatRequest {
messages: Self::convert_messages(request.messages),
temperature,
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
tools,
};
let response = self
.http_client()
.post(self.chat_completions_url())
.header("api-key", credential.as_str())
.json(&native_request)
.send()
.await?;
if !response.status().is_success() {
return Err(super::api_error("Azure OpenAI", response).await);
}
let native_response: NativeChatResponse = response.json().await?;
let usage = native_response.usage.map(|u| TokenUsage {
input_tokens: u.prompt_tokens,
output_tokens: u.completion_tokens,
cached_input_tokens: None,
});
let message = native_response
.choices
.into_iter()
.next()
.map(|c| c.message)
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
let mut result = Self::parse_native_response(message);
result.usage = usage;
Ok(result)
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: &[serde_json::Value],
_model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"Azure OpenAI API key not set. Set AZURE_OPENAI_API_KEY or edit config.toml."
)
})?;
let native_tools: Option<Vec<NativeToolSpec>> = if tools.is_empty() {
None
} else {
Some(
tools
.iter()
.cloned()
.map(parse_native_tool_spec)
.collect::<Result<Vec<_>, _>>()?,
)
};
let native_request = NativeChatRequest {
messages: Self::convert_messages(messages),
temperature,
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
tools: native_tools,
};
let response = self
.http_client()
.post(self.chat_completions_url())
.header("api-key", credential.as_str())
.json(&native_request)
.send()
.await?;
if !response.status().is_success() {
return Err(super::api_error("Azure OpenAI", response).await);
}
let native_response: NativeChatResponse = response.json().await?;
let usage = native_response.usage.map(|u| TokenUsage {
input_tokens: u.prompt_tokens,
output_tokens: u.completion_tokens,
cached_input_tokens: None,
});
let message = native_response
.choices
.into_iter()
.next()
.map(|c| c.message)
.ok_or_else(|| anyhow::anyhow!("No response from Azure OpenAI"))?;
let mut result = Self::parse_native_response(message);
result.usage = usage;
Ok(result)
}
async fn warmup(&self) -> anyhow::Result<()> {
// Azure OpenAI does not have a lightweight models endpoint,
// so warmup is a no-op to avoid unnecessary API calls.
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn url_construction_default_version() {
let p = AzureOpenAiProvider::new(Some("test-key"), "my-resource", "gpt-4o", None);
assert_eq!(
p.chat_completions_url(),
"https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
);
}
#[test]
fn url_construction_custom_version() {
let p = AzureOpenAiProvider::new(
Some("test-key"),
"my-resource",
"gpt-4o",
Some("2024-06-01"),
);
assert_eq!(
p.chat_completions_url(),
"https://my-resource.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2024-06-01"
);
}
#[test]
fn url_construction_preserves_resource_and_deployment() {
let p = AzureOpenAiProvider::new(Some("key"), "contoso-ai", "my-gpt35-deployment", None);
let url = p.chat_completions_url();
assert!(url.contains("contoso-ai.openai.azure.com"));
assert!(url.contains("/deployments/my-gpt35-deployment/"));
assert!(url.contains("api-version=2024-08-01-preview"));
}
#[test]
fn auth_header_uses_api_key_not_bearer() {
// This test verifies the provider stores the credential correctly
// and that the auth header name is "api-key" (verified via the
// implementation in chat_with_system which uses .header("api-key", ...)).
let p = AzureOpenAiProvider::new(Some("my-azure-key"), "resource", "deployment", None);
assert_eq!(p.credential.as_deref(), Some("my-azure-key"));
}
#[test]
fn creates_with_credential() {
let p = AzureOpenAiProvider::new(
Some("azure-test-credential"),
"resource",
"deployment",
None,
);
assert_eq!(p.credential.as_deref(), Some("azure-test-credential"));
assert_eq!(p.resource_name, "resource");
assert_eq!(p.deployment_name, "deployment");
assert_eq!(p.api_version, DEFAULT_API_VERSION);
}
#[test]
fn creates_without_credential() {
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
assert!(p.credential.is_none());
}
#[tokio::test]
async fn chat_fails_without_key() {
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
let result = p.chat_with_system(None, "hello", "gpt-4o", 0.7).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("API key not set"));
}
#[tokio::test]
async fn chat_with_system_fails_without_key() {
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
let result = p
.chat_with_system(Some("You are ZeroClaw"), "test", "gpt-4o", 0.5)
.await;
assert!(result.is_err());
}
#[test]
fn request_serializes_with_system_message() {
let req = ChatRequest {
messages: vec![
Message {
role: "system".to_string(),
content: "You are ZeroClaw".to_string(),
},
Message {
role: "user".to_string(),
content: "hello".to_string(),
},
],
temperature: 0.7,
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"role\":\"system\""));
assert!(json.contains("\"role\":\"user\""));
// Azure requests should NOT contain a model field (deployment is in the URL)
assert!(!json.contains("\"model\""));
}
#[test]
fn request_serializes_without_system() {
let req = ChatRequest {
messages: vec![Message {
role: "user".to_string(),
content: "hello".to_string(),
}],
temperature: 0.0,
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("system"));
assert!(json.contains("\"temperature\":0.0"));
}
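// Sketch of how the skip_serializing_if attributes keep the optional tool
// fields out of the wire format when no tools are supplied.
#[test]
fn native_request_omits_tools_when_none() {
let req = NativeChatRequest {
messages: vec![],
temperature: 0.2,
tools: None,
tool_choice: None,
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("tools"));
assert!(!json.contains("tool_choice"));
}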
#[test]
fn response_deserializes_single_choice() {
let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices.len(), 1);
assert_eq!(resp.choices[0].message.effective_content(), "Hi!");
}
#[test]
fn response_deserializes_empty_choices() {
let json = r#"{"choices":[]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.choices.is_empty());
}
#[test]
fn response_deserializes_multiple_choices() {
let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices.len(), 2);
assert_eq!(resp.choices[0].message.effective_content(), "A");
}
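// Reasoning models may return an empty `content` alongside a populated
// `reasoning_content`; effective_content falls back to the latter.
#[test]
fn response_falls_back_to_reasoning_content() {
let json = r#"{"choices":[{"message":{"content":"","reasoning_content":"thinking"}}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices[0].message.effective_content(), "thinking");
}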
#[test]
fn tool_call_response_parsing() {
let json = r#"{"choices":[{"message":{
"content":"Let me check",
"tool_calls":[{
"id":"call_abc123",
"type":"function",
"function":{"name":"shell","arguments":"{\"command\":\"ls\"}"}
}]
}}],"usage":{"prompt_tokens":50,"completion_tokens":25}}"#;
let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
let message = resp.choices.into_iter().next().unwrap().message;
let parsed = AzureOpenAiProvider::parse_native_response(message);
assert_eq!(parsed.text.as_deref(), Some("Let me check"));
assert_eq!(parsed.tool_calls.len(), 1);
assert_eq!(parsed.tool_calls[0].id, "call_abc123");
assert_eq!(parsed.tool_calls[0].name, "shell");
assert!(parsed.tool_calls[0].arguments.contains("ls"));
}
#[test]
fn tool_call_response_without_id_generates_uuid() {
let json = r#"{"choices":[{"message":{
"content":null,
"tool_calls":[{
"function":{"name":"test","arguments":"{}"}
}]
}}]}"#;
let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
let message = resp.choices.into_iter().next().unwrap().message;
let parsed = AzureOpenAiProvider::parse_native_response(message);
assert_eq!(parsed.tool_calls.len(), 1);
assert!(!parsed.tool_calls[0].id.is_empty());
}
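// parse_native_tool_spec accepts only "type": "function" specifications.
#[test]
fn parse_native_tool_spec_rejects_non_function_kind() {
let good = serde_json::json!({
"type": "function",
"function": {"name": "shell", "description": "run", "parameters": {"type": "object"}}
});
assert!(parse_native_tool_spec(good).is_ok());
let bad = serde_json::json!({
"type": "retrieval",
"function": {"name": "shell", "description": "run", "parameters": {}}
});
let err = parse_native_tool_spec(bad).unwrap_err();
assert!(err.to_string().contains("unsupported tool type"));
}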
#[tokio::test]
async fn chat_with_tools_fails_without_key() {
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
let messages = vec![ChatMessage::user("hello".to_string())];
let tools = vec![serde_json::json!({
"type": "function",
"function": {
"name": "shell",
"description": "Run a shell command",
"parameters": {
"type": "object",
"properties": {
"command": { "type": "string" }
},
"required": ["command"]
}
}
})];
let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("API key not set"));
}
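// A sketch of the assistant tool-call round trip; assumes ChatMessage::assistant
// accepts the JSON envelope ({"content", "tool_calls"}) that convert_messages parses.
#[test]
fn convert_messages_parses_assistant_tool_calls() {
let payload = serde_json::json!({
"content": "Running ls",
"tool_calls": [{"id": "call_1", "name": "shell", "arguments": "{\"command\":\"ls\"}"}]
})
.to_string();
let converted = AzureOpenAiProvider::convert_messages(&[ChatMessage::assistant(payload)]);
assert_eq!(converted.len(), 1);
assert_eq!(converted[0].role, "assistant");
let calls = converted[0].tool_calls.as_ref().expect("tool calls preserved");
assert_eq!(calls[0].id.as_deref(), Some("call_1"));
assert_eq!(calls[0].function.name, "shell");
}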
#[test]
fn native_response_parses_usage() {
let json = r#"{
"choices": [{"message": {"content": "Hello"}}],
"usage": {"prompt_tokens": 100, "completion_tokens": 50}
}"#;
let resp: NativeChatResponse = serde_json::from_str(json).unwrap();
let usage = resp.usage.unwrap();
assert_eq!(usage.prompt_tokens, Some(100));
assert_eq!(usage.completion_tokens, Some(50));
}
#[test]
fn capabilities_reports_native_tools_and_vision() {
let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
let caps = <AzureOpenAiProvider as Provider>::capabilities(&p);
assert!(caps.native_tool_calling);
assert!(caps.vision);
}
#[test]
fn supports_native_tools_returns_true() {
let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
assert!(p.supports_native_tools());
}
#[test]
fn supports_vision_returns_true() {
let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", None);
assert!(p.supports_vision());
}
#[tokio::test]
async fn warmup_is_noop() {
let p = AzureOpenAiProvider::new(None, "resource", "deployment", None);
let result = p.warmup().await;
assert!(result.is_ok());
}
#[test]
fn custom_api_version_stored() {
let p = AzureOpenAiProvider::new(Some("key"), "resource", "deployment", Some("2025-01-01"));
assert_eq!(p.api_version, "2025-01-01");
}
}

File diff suppressed because it is too large

@@ -0,0 +1,523 @@
//! Claude Code headless CLI provider.
//!
//! Integrates with the Claude Code CLI, spawning the `claude` binary
//! as a subprocess for each inference request. This allows using Claude's AI
//! models without an interactive UI session.
//!
//! # Usage
//!
//! The `claude` binary must be available in `PATH`, or its location must be
//! set via the `CLAUDE_CODE_PATH` environment variable.
//!
//! Claude Code is invoked as:
//! ```text
//! claude --print -
//! ```
//! with prompt content written to stdin.
//!
//! # Limitations
//!
//! - **System prompt**: The system prompt is prepended to the user message with a
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
//! - **Temperature**: The CLI does not expose a temperature parameter.
//! Unsupported finite values are clamped to the nearest baseline (0.7 or 1.0);
//! non-finite values return an explicit error.
//!
//! # Authentication
//!
//! Authentication is handled by Claude Code itself (its own credential store).
//! No explicit API key is required by this provider.
//!
//! # Environment variables
//!
//! - `CLAUDE_CODE_PATH` — override the path to the `claude` binary (default: `"claude"`)
use crate::providers::traits::{ChatMessage, ChatRequest, ChatResponse, Provider, TokenUsage};
use async_trait::async_trait;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::time::{timeout, Duration};
/// Environment variable for overriding the path to the `claude` binary.
pub const CLAUDE_CODE_PATH_ENV: &str = "CLAUDE_CODE_PATH";
/// Default `claude` binary name (resolved via `PATH`).
const DEFAULT_CLAUDE_CODE_BINARY: &str = "claude";
/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// Claude Code requests are bounded to avoid hung subprocesses.
const CLAUDE_CODE_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_CLAUDE_CODE_STDERR_CHARS: usize = 512;
/// The CLI does not expose sampling controls; unsupported finite values are clamped to these baselines.
const CLAUDE_CODE_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
const TEMP_EPSILON: f64 = 1e-9;
/// Provider that invokes the Claude Code CLI as a subprocess.
///
/// Each inference request spawns a fresh `claude` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct ClaudeCodeProvider {
/// Path to the `claude` binary.
binary_path: PathBuf,
}
impl ClaudeCodeProvider {
/// Create a new `ClaudeCodeProvider`.
///
/// The binary path is resolved from `CLAUDE_CODE_PATH` env var if set,
/// otherwise defaults to `"claude"` (found via `PATH`).
pub fn new() -> Self {
let binary_path = std::env::var(CLAUDE_CODE_PATH_ENV)
.ok()
.filter(|path| !path.trim().is_empty())
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from(DEFAULT_CLAUDE_CODE_BINARY));
Self { binary_path }
}
/// Returns true if the model argument should be forwarded to the CLI.
fn should_forward_model(model: &str) -> bool {
let trimmed = model.trim();
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
}
fn supports_temperature(temperature: f64) -> bool {
CLAUDE_CODE_SUPPORTED_TEMPERATURES
.iter()
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
}
fn validate_temperature(temperature: f64) -> anyhow::Result<f64> {
if !temperature.is_finite() {
anyhow::bail!("Claude Code provider received non-finite temperature value");
}
if Self::supports_temperature(temperature) {
return Ok(temperature);
}
// Clamp to the nearest supported value — the CLI ignores temperature
// anyway, so a hard error just blocks callers like memory consolidation
// that legitimately request low temperatures.
let clamped = *CLAUDE_CODE_SUPPORTED_TEMPERATURES
.iter()
.min_by(|a, b| {
(temperature - **a)
.abs()
.partial_cmp(&(temperature - **b).abs())
.unwrap()
})
.unwrap();
tracing::debug!(
requested = temperature,
clamped = clamped,
"Clamped unsupported temperature to nearest Claude Code CLI value"
);
Ok(clamped)
}
fn redact_stderr(stderr: &[u8]) -> String {
let text = String::from_utf8_lossy(stderr);
let trimmed = text.trim();
if trimmed.is_empty() {
return String::new();
}
if trimmed.chars().count() <= MAX_CLAUDE_CODE_STDERR_CHARS {
return trimmed.to_string();
}
let clipped: String = trimmed.chars().take(MAX_CLAUDE_CODE_STDERR_CHARS).collect();
format!("{clipped}...")
}
/// Invoke the claude binary with the given prompt and optional model.
/// Returns the trimmed stdout output as the assistant response.
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
let mut cmd = Command::new(&self.binary_path);
cmd.arg("--print");
if Self::should_forward_model(model) {
cmd.arg("--model").arg(model);
}
// Read prompt from stdin to avoid exposing sensitive content in process args.
cmd.arg("-");
cmd.kill_on_drop(true);
cmd.stdin(std::process::Stdio::piped());
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn().map_err(|err| {
anyhow::anyhow!(
"Failed to spawn Claude Code binary at {}: {err}. \
Ensure `claude` is installed and in PATH, or set CLAUDE_CODE_PATH.",
self.binary_path.display()
)
})?;
if let Some(mut stdin) = child.stdin.take() {
stdin.write_all(message.as_bytes()).await.map_err(|err| {
anyhow::anyhow!("Failed to write prompt to Claude Code stdin: {err}")
})?;
stdin.shutdown().await.map_err(|err| {
anyhow::anyhow!("Failed to finalize Claude Code stdin stream: {err}")
})?;
}
let output = timeout(CLAUDE_CODE_REQUEST_TIMEOUT, child.wait_with_output())
.await
.map_err(|_| {
anyhow::anyhow!(
"Claude Code request timed out after {:?} (binary: {})",
CLAUDE_CODE_REQUEST_TIMEOUT,
self.binary_path.display()
)
})?
.map_err(|err| anyhow::anyhow!("Claude Code process failed: {err}"))?;
if !output.status.success() {
let code = output.status.code().unwrap_or(-1);
let stderr_excerpt = Self::redact_stderr(&output.stderr);
let stderr_note = if stderr_excerpt.is_empty() {
String::new()
} else {
format!(" Stderr: {stderr_excerpt}")
};
anyhow::bail!(
"Claude Code exited with non-zero status {code}. \
Check that Claude Code is authenticated and the CLI is supported.{stderr_note}"
);
}
let text = String::from_utf8(output.stdout)
.map_err(|err| anyhow::anyhow!("Claude Code produced non-UTF-8 output: {err}"))?;
Ok(text.trim().to_string())
}
}
impl Default for ClaudeCodeProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Provider for ClaudeCodeProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
Self::validate_temperature(temperature)?;
let full_message = match system_prompt {
Some(system) if !system.is_empty() => {
format!("{system}\n\n{message}")
}
_ => message.to_string(),
};
self.invoke_cli(&full_message, model).await
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
Self::validate_temperature(temperature)?;
// Separate system prompt from conversation messages.
let system = messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.as_str());
// Build conversation turns (skip system messages).
let turns: Vec<&ChatMessage> = messages.iter().filter(|m| m.role != "system").collect();
// If there's only one user message, use the simple path.
if turns.len() <= 1 {
let last_user = turns.first().map(|m| m.content.as_str()).unwrap_or("");
let full_message = match system {
Some(s) if !s.is_empty() => format!("{s}\n\n{last_user}"),
_ => last_user.to_string(),
};
return self.invoke_cli(&full_message, model).await;
}
// Format multi-turn conversation into a single prompt.
let mut parts = Vec::new();
if let Some(s) = system {
if !s.is_empty() {
parts.push(format!("[system]\n{s}"));
}
}
for msg in &turns {
let label = match msg.role.as_str() {
"user" => "[user]",
"assistant" => "[assistant]",
other => other,
};
parts.push(format!("{label}\n{}", msg.content));
}
parts.push("[assistant]".to_string());
let full_message = parts.join("\n\n");
self.invoke_cli(&full_message, model).await
}
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
usage: Some(TokenUsage::default()),
reasoning_content: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Mutex, OnceLock};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock poisoned")
}
/// Serialize tests that spawn the echo-provider script.
///
/// On Linux, writing a shell script and exec'ing it from parallel threads
/// can trigger `ETXTBSY` ("Text file busy") even with unique file paths,
/// because the kernel briefly holds `deny_write_access` on the interpreter
/// page cache. Serializing these tests eliminates the race.
///
/// Uses `tokio::sync::Mutex` so the guard can be held across `.await`.
fn script_mutex() -> &'static tokio::sync::Mutex<()> {
static LOCK: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| tokio::sync::Mutex::new(()))
}
#[test]
fn new_uses_env_override() {
let _guard = env_lock();
let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
std::env::set_var(CLAUDE_CODE_PATH_ENV, "/usr/local/bin/claude");
let provider = ClaudeCodeProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/claude"));
match orig {
Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
}
}
#[test]
fn new_defaults_to_claude() {
let _guard = env_lock();
let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
std::env::remove_var(CLAUDE_CODE_PATH_ENV);
let provider = ClaudeCodeProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("claude"));
if let Some(v) = orig {
std::env::set_var(CLAUDE_CODE_PATH_ENV, v);
}
}
#[test]
fn new_ignores_blank_env_override() {
let _guard = env_lock();
let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
std::env::set_var(CLAUDE_CODE_PATH_ENV, " ");
let provider = ClaudeCodeProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("claude"));
match orig {
Some(v) => std::env::set_var(CLAUDE_CODE_PATH_ENV, v),
None => std::env::remove_var(CLAUDE_CODE_PATH_ENV),
}
}
#[test]
fn should_forward_model_standard() {
assert!(ClaudeCodeProvider::should_forward_model(
"claude-sonnet-4-20250514"
));
assert!(ClaudeCodeProvider::should_forward_model(
"claude-3.5-sonnet"
));
}
#[test]
fn should_not_forward_default_model() {
assert!(!ClaudeCodeProvider::should_forward_model(
DEFAULT_MODEL_MARKER
));
assert!(!ClaudeCodeProvider::should_forward_model(""));
assert!(!ClaudeCodeProvider::should_forward_model(" "));
}
#[test]
fn validate_temperature_allows_defaults() {
assert!(ClaudeCodeProvider::validate_temperature(0.7).is_ok());
assert!(ClaudeCodeProvider::validate_temperature(1.0).is_ok());
}
#[test]
fn validate_temperature_clamps_custom_value() {
let clamped = ClaudeCodeProvider::validate_temperature(0.2).unwrap();
assert!((clamped - 0.7).abs() < 1e-9, "0.2 should clamp to 0.7");
let clamped = ClaudeCodeProvider::validate_temperature(0.9).unwrap();
assert!((clamped - 1.0).abs() < 1e-9, "0.9 should clamp to 1.0");
}
#[test]
fn validate_temperature_rejects_non_finite() {
assert!(ClaudeCodeProvider::validate_temperature(f64::NAN).is_err());
assert!(ClaudeCodeProvider::validate_temperature(f64::INFINITY).is_err());
}
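// redact_stderr trims whitespace and clips oversized output to
// MAX_CLAUDE_CODE_STDERR_CHARS characters plus an ellipsis.
#[test]
fn redact_stderr_trims_and_clips() {
assert_eq!(ClaudeCodeProvider::redact_stderr(b" \n"), "");
let long = "x".repeat(MAX_CLAUDE_CODE_STDERR_CHARS + 100);
let clipped = ClaudeCodeProvider::redact_stderr(long.as_bytes());
assert!(clipped.ends_with("..."));
assert_eq!(clipped.chars().count(), MAX_CLAUDE_CODE_STDERR_CHARS + 3);
}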
#[tokio::test]
async fn invoke_missing_binary_returns_error() {
let provider = ClaudeCodeProvider {
binary_path: PathBuf::from("/nonexistent/path/to/claude"),
};
let result = provider.invoke_cli("hello", "default").await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("Failed to spawn Claude Code binary"),
"unexpected error message: {msg}"
);
}
/// Helper: create a provider that uses a shell script echoing stdin back.
/// The script ignores CLI flags (`--print`, `--model`, `-`) and just cats stdin.
///
/// Each invocation places the script in its own unique directory and writes
/// the file atomically via `std::fs::write` to avoid `ETXTBSY` ("Text file
/// busy") races that occur when parallel test threads create and exec
/// scripts concurrently on the same filesystem.
fn echo_provider() -> ClaudeCodeProvider {
static SCRIPT_ID: AtomicUsize = AtomicUsize::new(0);
let script_id = SCRIPT_ID.fetch_add(1, Ordering::Relaxed);
let dir = std::env::temp_dir().join(format!(
"zeroclaw_test_claude_code_{}_{}",
std::process::id(),
script_id
));
std::fs::create_dir_all(&dir).unwrap();
let path = dir.join("fake_claude.sh");
std::fs::write(&path, "#!/bin/sh\ncat /dev/stdin\n").unwrap();
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap();
}
ClaudeCodeProvider { binary_path: path }
}
#[test]
fn echo_provider_uses_unique_script_paths() {
let first = echo_provider();
let second = echo_provider();
assert_ne!(first.binary_path, second.binary_path);
}
#[tokio::test]
async fn chat_with_history_single_user_message() {
let _lock = script_mutex().lock().await;
let provider = echo_provider();
let messages = vec![ChatMessage::user("hello")];
let result = provider
.chat_with_history(&messages, "default", 1.0)
.await
.unwrap();
assert_eq!(result, "hello");
}
#[tokio::test]
async fn chat_with_history_single_user_with_system() {
let _lock = script_mutex().lock().await;
let provider = echo_provider();
let messages = vec![
ChatMessage::system("You are helpful."),
ChatMessage::user("hello"),
];
let result = provider
.chat_with_history(&messages, "default", 1.0)
.await
.unwrap();
assert_eq!(result, "You are helpful.\n\nhello");
}
#[tokio::test]
async fn chat_with_history_multi_turn_includes_all_messages() {
let _lock = script_mutex().lock().await;
let provider = echo_provider();
let messages = vec![
ChatMessage::system("Be concise."),
ChatMessage::user("What is 2+2?"),
ChatMessage::assistant("4"),
ChatMessage::user("And 3+3?"),
];
let result = provider
.chat_with_history(&messages, "default", 1.0)
.await
.unwrap();
assert!(result.contains("[system]\nBe concise."));
assert!(result.contains("[user]\nWhat is 2+2?"));
assert!(result.contains("[assistant]\n4"));
assert!(result.contains("[user]\nAnd 3+3?"));
assert!(result.ends_with("[assistant]"));
}
#[tokio::test]
async fn chat_with_history_multi_turn_without_system() {
let _lock = script_mutex().lock().await;
let provider = echo_provider();
let messages = vec![
ChatMessage::user("hi"),
ChatMessage::assistant("hello"),
ChatMessage::user("bye"),
];
let result = provider
.chat_with_history(&messages, "default", 1.0)
.await
.unwrap();
assert!(!result.contains("[system]"));
assert!(result.contains("[user]\nhi"));
assert!(result.contains("[assistant]\nhello"));
assert!(result.contains("[user]\nbye"));
}
#[tokio::test]
async fn chat_with_history_clamps_bad_temperature() {
let _lock = script_mutex().lock().await;
let provider = echo_provider();
let messages = vec![ChatMessage::user("test")];
let result = provider.chat_with_history(&messages, "default", 0.5).await;
assert!(
result.is_ok(),
"unsupported temperature should be clamped, not rejected"
);
}
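// chat_with_system prepends the system prompt with a blank-line separator,
// exercised here against the stdin-echoing fake binary.
#[tokio::test]
async fn chat_with_system_prepends_system_prompt() {
let _lock = script_mutex().lock().await;
let provider = echo_provider();
let result = provider
.chat_with_system(Some("You are ZeroClaw"), "hello", "default", 1.0)
.await
.unwrap();
assert_eq!(result, "You are ZeroClaw\n\nhello");
}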
}

File diff suppressed because it is too large

@@ -0,0 +1,822 @@
//! GitHub Copilot provider with OAuth device-flow authentication.
//!
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
//! then exchanges the OAuth token for short-lived Copilot API keys.
//! Tokens are cached to disk and auto-refreshed.
//!
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
//! and other third-party Copilot integrations. The Copilot token endpoint is
//! private; there is no public OAuth scope or app registration for it.
//! GitHub could change or revoke this at any time, which would break all
//! third-party integrations simultaneously.
use crate::providers::traits::{
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
Provider, TokenUsage, ToolCall as ProviderToolCall,
};
use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::warn;
/// GitHub OAuth client ID for Copilot (VS Code extension).
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const DEFAULT_API: &str = "https://api.githubcopilot.com";
// ── Token types ──────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
device_code: String,
user_code: String,
verification_uri: String,
#[serde(default = "default_interval")]
interval: u64,
#[serde(default = "default_expires_in")]
expires_in: u64,
}
fn default_interval() -> u64 {
5
}
fn default_expires_in() -> u64 {
900
}
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
access_token: Option<String>,
error: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ApiKeyInfo {
token: String,
expires_at: i64,
#[serde(default)]
endpoints: Option<ApiEndpoints>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ApiEndpoints {
api: Option<String>,
}
struct CachedApiKey {
token: String,
api_endpoint: String,
expires_at: i64,
}
// ── Chat completions types ───────────────────────────────────────
#[derive(Debug, Serialize)]
struct ApiChatRequest<'a> {
model: String,
messages: Vec<ApiMessage>,
temperature: f64,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<NativeToolSpec<'a>>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<String>,
}
#[derive(Debug, Serialize)]
struct ApiMessage {
role: String,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<ApiContent>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<NativeToolCall>>,
}
#[derive(Debug, Serialize)]
struct NativeToolSpec<'a> {
#[serde(rename = "type")]
kind: &'static str,
function: NativeToolFunctionSpec<'a>,
}
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec<'a> {
name: &'a str,
description: &'a str,
parameters: &'a serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
kind: Option<String>,
function: NativeFunctionCall,
}
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
name: String,
arguments: String,
}
/// Multi-part content for vision messages (OpenAI format).
#[derive(Debug, Clone, Serialize)]
#[serde(untagged)]
enum ApiContent {
Text(String),
Parts(Vec<ContentPart>),
}
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
enum ContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrlDetail },
}
#[derive(Debug, Clone, Serialize)]
struct ImageUrlDetail {
url: String,
}
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
choices: Vec<Choice>,
#[serde(default)]
usage: Option<UsageInfo>,
}
#[derive(Debug, Deserialize)]
struct UsageInfo {
#[serde(default)]
prompt_tokens: Option<u64>,
#[serde(default)]
completion_tokens: Option<u64>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<NativeToolCall>>,
}
// ── Provider ─────────────────────────────────────────────────────
/// GitHub Copilot provider with automatic OAuth and token refresh.
///
/// On first use, prompts the user to visit github.com/login/device.
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
/// automatically.
pub struct CopilotProvider {
github_token: Option<String>,
/// Mutex ensures only one caller refreshes tokens at a time,
/// preventing duplicate device flow prompts or redundant API calls.
refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
token_dir: PathBuf,
}
impl CopilotProvider {
pub fn new(github_token: Option<&str>) -> Self {
let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
.map(|dir| dir.config_dir().join("copilot"))
.unwrap_or_else(|| {
// Fall back to a user-specific temp directory to avoid
// shared-directory symlink attacks.
let user = std::env::var("USER")
.or_else(|_| std::env::var("USERNAME"))
.unwrap_or_else(|_| "unknown".to_string());
std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
});
if let Err(err) = std::fs::create_dir_all(&token_dir) {
warn!(
"Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
token_dir
);
} else {
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
if let Err(err) =
std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
{
warn!(
"Failed to set Copilot token directory permissions on {:?}: {err}",
token_dir
);
}
}
}
Self {
github_token: github_token
.filter(|token| !token.is_empty())
.map(String::from),
refresh_lock: Arc::new(Mutex::new(None)),
token_dir,
}
}
fn http_client(&self) -> Client {
crate::config::build_runtime_proxy_client_with_timeouts("provider.copilot", 120, 10)
}
/// Required headers for Copilot API requests (editor identification).
const COPILOT_HEADERS: [(&str, &str); 4] = [
("Editor-Version", "vscode/1.85.1"),
("Editor-Plugin-Version", "copilot/1.155.0"),
("User-Agent", "GithubCopilot/1.155.0"),
("Accept", "application/json"),
];
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec<'_>>> {
tools.map(|items| {
items
.iter()
.map(|tool| NativeToolSpec {
kind: "function",
function: NativeToolFunctionSpec {
name: &tool.name,
description: &tool.description,
parameters: &tool.parameters,
},
})
.collect()
})
}
/// Convert message content to API format, with multi-part support for
/// user messages containing `[IMAGE:...]` markers.
fn to_api_content(role: &str, content: &str) -> Option<ApiContent> {
if role != "user" {
return Some(ApiContent::Text(content.to_string()));
}
let (cleaned_text, image_refs) = crate::multimodal::parse_image_markers(content);
if image_refs.is_empty() {
return Some(ApiContent::Text(content.to_string()));
}
let mut parts = Vec::with_capacity(image_refs.len() + 1);
let trimmed = cleaned_text.trim();
if !trimmed.is_empty() {
parts.push(ContentPart::Text {
text: trimmed.to_string(),
});
}
for image_ref in image_refs {
parts.push(ContentPart::ImageUrl {
image_url: ImageUrlDetail { url: image_ref },
});
}
Some(ApiContent::Parts(parts))
}
fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
messages
.iter()
.map(|message| {
if message.role == "assistant" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
if let Some(tool_calls_value) = value.get("tool_calls") {
if let Ok(parsed_calls) =
serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
{
let tool_calls = parsed_calls
.into_iter()
.map(|tool_call| NativeToolCall {
id: Some(tool_call.id),
kind: Some("function".to_string()),
function: NativeFunctionCall {
name: tool_call.name,
arguments: tool_call.arguments,
},
})
.collect::<Vec<_>>();
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(|s| ApiContent::Text(s.to_string()));
return ApiMessage {
role: "assistant".to_string(),
content,
tool_call_id: None,
tool_calls: Some(tool_calls),
};
}
}
}
}
if message.role == "tool" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
let tool_call_id = value
.get("tool_call_id")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(|s| ApiContent::Text(s.to_string()));
return ApiMessage {
role: "tool".to_string(),
content,
tool_call_id,
tool_calls: None,
};
}
}
ApiMessage {
role: message.role.clone(),
content: Self::to_api_content(&message.role, &message.content),
tool_call_id: None,
tool_calls: None,
}
})
.collect()
}
/// Send a chat completions request with required Copilot headers.
async fn send_chat_request(
&self,
messages: Vec<ApiMessage>,
tools: Option<&[ToolSpec]>,
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let (token, endpoint) = self.get_api_key().await?;
let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
let native_tools = Self::convert_tools(tools);
let request = ApiChatRequest {
model: model.to_string(),
messages,
temperature,
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
tools: native_tools,
};
let mut req = self
.http_client()
.post(&url)
.header("Authorization", format!("Bearer {token}"))
.json(&request);
for (header, value) in &Self::COPILOT_HEADERS {
req = req.header(*header, *value);
}
let response = req.send().await?;
if !response.status().is_success() {
return Err(super::api_error("GitHub Copilot", response).await);
}
let api_response: ApiChatResponse = response.json().await?;
let usage = api_response.usage.map(|u| TokenUsage {
input_tokens: u.prompt_tokens,
output_tokens: u.completion_tokens,
cached_input_tokens: None,
});
let choice = api_response
.choices
.into_iter()
.next()
.ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
let tool_calls = choice
.message
.tool_calls
.unwrap_or_default()
.into_iter()
.map(|tool_call| ProviderToolCall {
id: tool_call
.id
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
name: tool_call.function.name,
arguments: tool_call.function.arguments,
})
.collect();
Ok(ProviderChatResponse {
text: choice.message.content,
tool_calls,
usage,
reasoning_content: None,
})
}
/// Get a valid Copilot API key, refreshing or re-authenticating as needed.
/// Uses a Mutex to ensure only one caller refreshes at a time.
async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
let mut cached = self.refresh_lock.lock().await;
if let Some(cached_key) = cached.as_ref() {
if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
}
}
if let Some(info) = self.load_api_key_from_disk().await {
if chrono::Utc::now().timestamp() + 120 < info.expires_at {
let endpoint = info
.endpoints
.as_ref()
.and_then(|e| e.api.clone())
.unwrap_or_else(|| DEFAULT_API.to_string());
let token = info.token;
*cached = Some(CachedApiKey {
token: token.clone(),
api_endpoint: endpoint.clone(),
expires_at: info.expires_at,
});
return Ok((token, endpoint));
}
}
let access_token = self.get_github_access_token().await?;
let api_key_info = self.exchange_for_api_key(&access_token).await?;
self.save_api_key_to_disk(&api_key_info).await;
let endpoint = api_key_info
.endpoints
.as_ref()
.and_then(|e| e.api.clone())
.unwrap_or_else(|| DEFAULT_API.to_string());
*cached = Some(CachedApiKey {
token: api_key_info.token.clone(),
api_endpoint: endpoint.clone(),
expires_at: api_key_info.expires_at,
});
Ok((api_key_info.token, endpoint))
}
/// Get a GitHub access token from config, cache, or device flow.
async fn get_github_access_token(&self) -> anyhow::Result<String> {
if let Some(token) = &self.github_token {
return Ok(token.clone());
}
let access_token_path = self.token_dir.join("access-token");
if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
let token = cached.trim();
if !token.is_empty() {
return Ok(token.to_string());
}
}
let token = self.device_code_login().await?;
write_file_secure(&access_token_path, &token).await;
Ok(token)
}
/// Run GitHub OAuth device code flow.
async fn device_code_login(&self) -> anyhow::Result<String> {
let response: DeviceCodeResponse = self
.http_client()
.post(GITHUB_DEVICE_CODE_URL)
.header("Accept", "application/json")
.json(&serde_json::json!({
"client_id": GITHUB_CLIENT_ID,
"scope": "read:user"
}))
.send()
.await?
.error_for_status()?
.json()
.await?;
let mut poll_interval = Duration::from_secs(response.interval.max(5));
let expires_in = response.expires_in.max(1);
let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
eprintln!(
"\nGitHub Copilot authentication is required.\n\
Visit: {}\n\
Code: {}\n\
Waiting for authorization...\n",
response.verification_uri, response.user_code
);
while tokio::time::Instant::now() < expires_at {
tokio::time::sleep(poll_interval).await;
let token_response: AccessTokenResponse = self
.http_client()
.post(GITHUB_ACCESS_TOKEN_URL)
.header("Accept", "application/json")
.json(&serde_json::json!({
"client_id": GITHUB_CLIENT_ID,
"device_code": response.device_code,
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
}))
.send()
.await?
.json()
.await?;
if let Some(token) = token_response.access_token {
eprintln!("Authentication succeeded.\n");
return Ok(token);
}
match token_response.error.as_deref() {
Some("slow_down") => {
poll_interval += Duration::from_secs(5);
}
Some("authorization_pending") | None => {}
Some("expired_token") => {
anyhow::bail!("GitHub device authorization expired")
}
Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
}
}
anyhow::bail!("Timed out waiting for GitHub authorization")
}
/// Exchange a GitHub access token for a Copilot API key.
async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
let mut request = self.http_client().get(GITHUB_API_KEY_URL);
for (header, value) in &Self::COPILOT_HEADERS {
request = request.header(*header, *value);
}
request = request.header("Authorization", format!("token {access_token}"));
let response = request.send().await?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
let sanitized = super::sanitize_api_error(&body);
if status.as_u16() == 401 || status.as_u16() == 403 {
let access_token_path = self.token_dir.join("access-token");
tokio::fs::remove_file(&access_token_path).await.ok();
}
anyhow::bail!(
"Failed to get Copilot API key ({status}): {sanitized}. \
Ensure your GitHub account has an active Copilot subscription."
);
}
let info: ApiKeyInfo = response.json().await?;
Ok(info)
}
async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
let path = self.token_dir.join("api-key.json");
let data = tokio::fs::read_to_string(&path).await.ok()?;
serde_json::from_str(&data).ok()
}
async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
let path = self.token_dir.join("api-key.json");
if let Ok(json) = serde_json::to_string_pretty(info) {
write_file_secure(&path, &json).await;
}
}
}
/// Write a file with 0600 permissions (owner read/write only).
/// Uses `spawn_blocking` to avoid blocking the async runtime.
async fn write_file_secure(path: &Path, content: &str) {
let path = path.to_path_buf();
let content = content.to_string();
let result = tokio::task::spawn_blocking(move || {
#[cfg(unix)]
{
use std::io::Write;
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
let mut file = std::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(&path)?;
file.write_all(content.as_bytes())?;
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
Ok::<(), std::io::Error>(())
}
#[cfg(not(unix))]
{
std::fs::write(&path, &content)?;
Ok::<(), std::io::Error>(())
}
})
.await;
match result {
Ok(Ok(())) => {}
Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
Err(err) => warn!("Failed to spawn blocking write: {err}"),
}
}
#[async_trait]
impl Provider for CopilotProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let mut messages = Vec::new();
if let Some(system) = system_prompt {
messages.push(ApiMessage {
role: "system".to_string(),
content: Some(ApiContent::Text(system.to_string())),
tool_call_id: None,
tool_calls: None,
});
}
messages.push(ApiMessage {
role: "user".to_string(),
content: Self::to_api_content("user", message),
tool_call_id: None,
tool_calls: None,
});
let response = self
.send_chat_request(messages, None, model, temperature)
.await?;
Ok(response.text.unwrap_or_default())
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let response = self
.send_chat_request(Self::convert_messages(messages), None, model, temperature)
.await?;
Ok(response.text.unwrap_or_default())
}
async fn chat(
&self,
request: ProviderChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
self.send_chat_request(
Self::convert_messages(request.messages),
request.tools,
model,
temperature,
)
.await
}
fn supports_native_tools(&self) -> bool {
true
}
async fn warmup(&self) -> anyhow::Result<()> {
let _ = self.get_api_key().await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn new_without_token() {
let provider = CopilotProvider::new(None);
assert!(provider.github_token.is_none());
}
#[test]
fn new_with_token() {
let provider = CopilotProvider::new(Some("ghp_test"));
assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
}
#[test]
fn empty_token_treated_as_none() {
let provider = CopilotProvider::new(Some(""));
assert!(provider.github_token.is_none());
}
#[tokio::test]
async fn cache_starts_empty() {
let provider = CopilotProvider::new(None);
let cached = provider.refresh_lock.lock().await;
assert!(cached.is_none());
}
#[test]
fn copilot_headers_include_required_fields() {
let headers = CopilotProvider::COPILOT_HEADERS;
assert!(headers
.iter()
.any(|(header, _)| *header == "Editor-Version"));
assert!(headers
.iter()
.any(|(header, _)| *header == "Editor-Plugin-Version"));
assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
}
#[test]
fn default_interval_and_expiry() {
assert_eq!(default_interval(), 5);
assert_eq!(default_expires_in(), 900);
}
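// write_file_secure should produce an owner-only (0600) file on Unix; this
// sketch uses a process-unique temp directory to avoid cross-test collisions.
#[tokio::test]
async fn write_file_secure_writes_owner_only_file() {
let dir = std::env::temp_dir().join(format!("zeroclaw_copilot_secure_{}", std::process::id()));
std::fs::create_dir_all(&dir).unwrap();
let path = dir.join("token-file");
write_file_secure(&path, "secret").await;
let contents = tokio::fs::read_to_string(&path).await.unwrap();
assert_eq!(contents, "secret");
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let mode = std::fs::metadata(&path).unwrap().permissions().mode();
assert_eq!(mode & 0o777, 0o600);
}
}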
#[test]
fn supports_native_tools() {
let provider = CopilotProvider::new(None);
assert!(provider.supports_native_tools());
}
#[test]
fn api_response_parses_usage() {
let json = r#"{
"choices": [{"message": {"content": "Hello"}}],
"usage": {"prompt_tokens": 200, "completion_tokens": 80}
}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
let usage = resp.usage.unwrap();
assert_eq!(usage.prompt_tokens, Some(200));
assert_eq!(usage.completion_tokens, Some(80));
}
#[test]
fn api_response_parses_without_usage() {
let json = r#"{"choices": [{"message": {"content": "Hello"}}]}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.usage.is_none());
}
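// Sketch of the cached api-key.json shape; unknown extra fields from the
// token endpoint are ignored during deserialization.
#[test]
fn api_key_info_round_trips_cache_format() {
let json = r#"{"token":"tid","expires_at":1700000000,"endpoints":{"api":"https://api.example.com"},"sku":"x"}"#;
let info: ApiKeyInfo = serde_json::from_str(json).unwrap();
assert_eq!(info.token, "tid");
assert_eq!(info.expires_at, 1_700_000_000);
let endpoint = info.endpoints.as_ref().and_then(|e| e.api.clone());
assert_eq!(endpoint.as_deref(), Some("https://api.example.com"));
assert!(serde_json::to_string(&info).unwrap().contains("expires_at"));
}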
#[test]
fn to_api_content_user_with_image_returns_parts() {
let content = "describe this [IMAGE:data:image/png;base64,abc123]";
let result = CopilotProvider::to_api_content("user", content).unwrap();
match result {
ApiContent::Parts(parts) => {
assert_eq!(parts.len(), 2);
assert!(matches!(&parts[0], ContentPart::Text { text } if text == "describe this"));
assert!(
matches!(&parts[1], ContentPart::ImageUrl { image_url } if image_url.url == "data:image/png;base64,abc123")
);
}
ApiContent::Text(_) => {
panic!("expected ApiContent::Parts for user message with image marker")
}
}
}
#[test]
fn to_api_content_user_plain_returns_text() {
let result = CopilotProvider::to_api_content("user", "hello world").unwrap();
assert!(matches!(result, ApiContent::Text(ref s) if s == "hello world"));
}
#[test]
fn to_api_content_non_user_returns_text() {
let result = CopilotProvider::to_api_content("system", "you are helpful").unwrap();
assert!(matches!(result, ApiContent::Text(ref s) if s == "you are helpful"));
let result = CopilotProvider::to_api_content("assistant", "sure").unwrap();
assert!(matches!(result, ApiContent::Text(ref s) if s == "sure"));
}
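// A sketch of assistant tool-call conversion; assumes ChatMessage::assistant
// accepts the JSON envelope ({"content", "tool_calls"}) parsed above.
#[test]
fn convert_messages_parses_assistant_tool_calls() {
let payload = serde_json::json!({
"content": "Checking",
"tool_calls": [{"id": "call_9", "name": "shell", "arguments": "{}"}]
})
.to_string();
let converted = CopilotProvider::convert_messages(&[ChatMessage::assistant(payload)]);
assert_eq!(converted.len(), 1);
let calls = converted[0].tool_calls.as_ref().expect("tool calls preserved");
assert_eq!(calls[0].id.as_deref(), Some("call_9"));
assert_eq!(calls[0].function.name, "shell");
assert!(matches!(converted[0].content, Some(ApiContent::Text(ref s)) if s == "Checking"));
}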
}

File diff suppressed because it is too large

@@ -0,0 +1,326 @@
//! Gemini CLI subprocess provider.
//!
//! Integrates with the Gemini CLI, spawning the `gemini` binary
//! as a subprocess for each inference request. This allows using Google's
//! Gemini models via the CLI without an interactive UI session.
//!
//! # Usage
//!
//! The `gemini` binary must be available in `PATH`, or its location must be
//! set via the `GEMINI_CLI_PATH` environment variable.
//!
//! Gemini CLI is invoked as:
//! ```text
//! gemini --print -
//! ```
//! with prompt content written to stdin.
//!
//! # Limitations
//!
//! - **Conversation history**: Only the system prompt (if present) and the last
//! user message are forwarded. Full multi-turn history is not preserved because
//! the CLI accepts a single prompt per invocation.
//! - **System prompt**: The system prompt is prepended to the user message with a
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
//! - **Temperature**: The CLI does not expose a temperature parameter.
//! Only default values are accepted; custom values return an explicit error.
//!
//! # Authentication
//!
//! Authentication is handled by the Gemini CLI itself (its own credential store).
//! No explicit API key is required by this provider.
//!
//! # Environment variables
//!
//! - `GEMINI_CLI_PATH` — override the path to the `gemini` binary (default: `"gemini"`)
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
use async_trait::async_trait;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::time::{timeout, Duration};
/// Environment variable for overriding the path to the `gemini` binary.
pub const GEMINI_CLI_PATH_ENV: &str = "GEMINI_CLI_PATH";
/// Default `gemini` binary name (resolved via `PATH`).
const DEFAULT_GEMINI_CLI_BINARY: &str = "gemini";
/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// Gemini CLI requests are bounded to avoid hung subprocesses.
const GEMINI_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_GEMINI_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const GEMINI_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
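/// Tolerance for floating-point comparison against the supported temperatures.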
const TEMP_EPSILON: f64 = 1e-9;
/// Provider that invokes the Gemini CLI as a subprocess.
///
/// Each inference request spawns a fresh `gemini` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct GeminiCliProvider {
/// Path to the `gemini` binary.
binary_path: PathBuf,
}
impl GeminiCliProvider {
/// Create a new `GeminiCliProvider`.
///
/// The binary path is resolved from `GEMINI_CLI_PATH` env var if set,
/// otherwise defaults to `"gemini"` (found via `PATH`).
pub fn new() -> Self {
let binary_path = std::env::var(GEMINI_CLI_PATH_ENV)
.ok()
.filter(|path| !path.trim().is_empty())
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from(DEFAULT_GEMINI_CLI_BINARY));
Self { binary_path }
}
/// Returns true if the model argument should be forwarded to the CLI.
fn should_forward_model(model: &str) -> bool {
let trimmed = model.trim();
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
}
fn supports_temperature(temperature: f64) -> bool {
GEMINI_CLI_SUPPORTED_TEMPERATURES
.iter()
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
}
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
if !temperature.is_finite() {
anyhow::bail!("Gemini CLI provider received non-finite temperature value");
}
if !Self::supports_temperature(temperature) {
anyhow::bail!(
"temperature unsupported by Gemini CLI: {temperature}. \
Supported values: 0.7 or 1.0"
);
}
Ok(())
}
fn redact_stderr(stderr: &[u8]) -> String {
let text = String::from_utf8_lossy(stderr);
let trimmed = text.trim();
if trimmed.is_empty() {
return String::new();
}
if trimmed.chars().count() <= MAX_GEMINI_CLI_STDERR_CHARS {
return trimmed.to_string();
}
let clipped: String = trimmed.chars().take(MAX_GEMINI_CLI_STDERR_CHARS).collect();
format!("{clipped}...")
}
/// Invoke the gemini binary with the given prompt and optional model.
/// Returns the trimmed stdout output as the assistant response.
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
let mut cmd = Command::new(&self.binary_path);
cmd.arg("--print");
if Self::should_forward_model(model) {
cmd.arg("--model").arg(model);
}
// Read prompt from stdin to avoid exposing sensitive content in process args.
cmd.arg("-");
cmd.kill_on_drop(true);
cmd.stdin(std::process::Stdio::piped());
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn().map_err(|err| {
anyhow::anyhow!(
"Failed to spawn Gemini CLI binary at {}: {err}. \
Ensure `gemini` is installed and in PATH, or set GEMINI_CLI_PATH.",
self.binary_path.display()
)
})?;
if let Some(mut stdin) = child.stdin.take() {
stdin.write_all(message.as_bytes()).await.map_err(|err| {
anyhow::anyhow!("Failed to write prompt to Gemini CLI stdin: {err}")
})?;
stdin.shutdown().await.map_err(|err| {
anyhow::anyhow!("Failed to finalize Gemini CLI stdin stream: {err}")
})?;
}
let output = timeout(GEMINI_CLI_REQUEST_TIMEOUT, child.wait_with_output())
.await
.map_err(|_| {
anyhow::anyhow!(
"Gemini CLI request timed out after {:?} (binary: {})",
GEMINI_CLI_REQUEST_TIMEOUT,
self.binary_path.display()
)
})?
.map_err(|err| anyhow::anyhow!("Gemini CLI process failed: {err}"))?;
if !output.status.success() {
let code = output.status.code().unwrap_or(-1);
let stderr_excerpt = Self::redact_stderr(&output.stderr);
let stderr_note = if stderr_excerpt.is_empty() {
String::new()
} else {
format!(" Stderr: {stderr_excerpt}")
};
anyhow::bail!(
"Gemini CLI exited with non-zero status {code}. \
Check that Gemini CLI is authenticated and the CLI is supported.{stderr_note}"
);
}
let text = String::from_utf8(output.stdout)
.map_err(|err| anyhow::anyhow!("Gemini CLI produced non-UTF-8 output: {err}"))?;
Ok(text.trim().to_string())
}
}
impl Default for GeminiCliProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Provider for GeminiCliProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
Self::validate_temperature(temperature)?;
let full_message = match system_prompt {
Some(system) if !system.is_empty() => {
format!("{system}\n\n{message}")
}
_ => message.to_string(),
};
self.invoke_cli(&full_message, model).await
}
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
usage: Some(TokenUsage::default()),
reasoning_content: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Mutex, OnceLock};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock poisoned")
}
#[test]
fn new_uses_env_override() {
let _guard = env_lock();
let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
std::env::set_var(GEMINI_CLI_PATH_ENV, "/usr/local/bin/gemini");
let provider = GeminiCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/gemini"));
match orig {
Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
}
}
#[test]
fn new_defaults_to_gemini() {
let _guard = env_lock();
let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
std::env::remove_var(GEMINI_CLI_PATH_ENV);
let provider = GeminiCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("gemini"));
if let Some(v) = orig {
std::env::set_var(GEMINI_CLI_PATH_ENV, v);
}
}
#[test]
fn new_ignores_blank_env_override() {
let _guard = env_lock();
let orig = std::env::var(GEMINI_CLI_PATH_ENV).ok();
std::env::set_var(GEMINI_CLI_PATH_ENV, " ");
let provider = GeminiCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("gemini"));
match orig {
Some(v) => std::env::set_var(GEMINI_CLI_PATH_ENV, v),
None => std::env::remove_var(GEMINI_CLI_PATH_ENV),
}
}
#[test]
fn should_forward_model_standard() {
assert!(GeminiCliProvider::should_forward_model("gemini-2.5-pro"));
assert!(GeminiCliProvider::should_forward_model("gemini-2.5-flash"));
}
#[test]
fn should_not_forward_default_model() {
assert!(!GeminiCliProvider::should_forward_model(
DEFAULT_MODEL_MARKER
));
assert!(!GeminiCliProvider::should_forward_model(""));
assert!(!GeminiCliProvider::should_forward_model(" "));
}
#[test]
fn validate_temperature_allows_defaults() {
assert!(GeminiCliProvider::validate_temperature(0.7).is_ok());
assert!(GeminiCliProvider::validate_temperature(1.0).is_ok());
}
#[test]
fn validate_temperature_rejects_custom_value() {
let err = GeminiCliProvider::validate_temperature(0.2).unwrap_err();
assert!(err
.to_string()
.contains("temperature unsupported by Gemini CLI"));
}
#[tokio::test]
async fn invoke_missing_binary_returns_error() {
let provider = GeminiCliProvider {
binary_path: PathBuf::from("/nonexistent/path/to/gemini"),
};
let result = provider.invoke_cli("hello", "default").await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("Failed to spawn Gemini CLI binary"),
"unexpected error message: {msg}"
);
}
}

View File

@@ -0,0 +1,361 @@
//! Zhipu GLM provider with JWT authentication.
//! The GLM API requires JWT tokens generated from the `id.secret` API key format
//! with a custom `sign_type: "SIGN"` header, and uses `/v4/chat/completions`.
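//!
//! # Example
//!
//! A minimal sketch (ignored by doctests); the module path is an assumption,
//! and the model name is borrowed from this file's tests:
//!
//! ```rust,ignore
//! use zeroclaw::providers::glm::GlmProvider;
//! use zeroclaw::providers::Provider;
//!
//! // The key must be in `id.secret` form; a JWT is minted and cached per call.
//! let provider = GlmProvider::new(Some("my-id.my-secret"));
//! let reply = provider
//!     .chat_with_system(None, "Hello", "glm-4.7", 0.7)
//!     .await?;
//! ```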
use crate::providers::traits::{ChatMessage, Provider};
use async_trait::async_trait;
use reqwest::Client;
use ring::hmac;
use serde::{Deserialize, Serialize};
use std::sync::Mutex;
use std::time::{SystemTime, UNIX_EPOCH};
pub struct GlmProvider {
api_key_id: String,
api_key_secret: String,
base_url: String,
/// Cached JWT token + expiry timestamp (ms)
token_cache: Mutex<Option<(String, u64)>>,
}
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
temperature: f64,
}
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
content: String,
}
/// Base64url encode without padding (per JWT spec).
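///
/// For example, `base64url_encode_bytes(b"hello")` yields `"aGVsbG8"`: standard
/// base64 would produce `aGVsbG8=`, but here the URL-safe alphabet is used and
/// no `=` padding is emitted.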
fn base64url_encode_bytes(data: &[u8]) -> String {
const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
let mut result = String::new();
let mut i = 0;
while i < data.len() {
let b0 = data[i] as u32;
let b1 = if i + 1 < data.len() { data[i + 1] as u32 } else { 0 };
let b2 = if i + 2 < data.len() { data[i + 2] as u32 } else { 0 };
let triple = (b0 << 16) | (b1 << 8) | b2;
result.push(CHARS[((triple >> 18) & 0x3F) as usize] as char);
result.push(CHARS[((triple >> 12) & 0x3F) as usize] as char);
if i + 1 < data.len() {
result.push(CHARS[((triple >> 6) & 0x3F) as usize] as char);
}
if i + 2 < data.len() {
result.push(CHARS[(triple & 0x3F) as usize] as char);
}
i += 3;
}
    // Convert to the base64url alphabet: '+' -> '-', '/' -> '_'.
    // No '=' padding is ever emitted above, so nothing needs stripping.
result.replace('+', "-").replace('/', "_")
}
fn base64url_encode_str(s: &str) -> String {
base64url_encode_bytes(s.as_bytes())
}
impl GlmProvider {
pub fn new(api_key: Option<&str>) -> Self {
let (id, secret) = api_key
.and_then(|k| k.split_once('.'))
.map(|(id, secret)| (id.to_string(), secret.to_string()))
.unwrap_or_default();
Self {
api_key_id: id,
api_key_secret: secret,
base_url: "https://api.z.ai/api/paas/v4".to_string(),
token_cache: Mutex::new(None),
}
}
fn generate_token(&self) -> anyhow::Result<String> {
if self.api_key_id.is_empty() || self.api_key_secret.is_empty() {
anyhow::bail!(
"GLM API key not set or invalid format. Expected 'id.secret'. \
Run `zeroclaw onboard` or set GLM_API_KEY env var."
);
}
let now_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)?
.as_millis() as u64;
// Check cache (valid for 3 minutes, token expires at 3.5 min)
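        // The 30-second margin between cache expiry (3 min) and token expiry
        // (3.5 min) ensures a cached token is never sent when about to lapse.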
if let Ok(cache) = self.token_cache.lock() {
if let Some((ref token, expiry)) = *cache {
if now_ms < expiry {
return Ok(token.clone());
}
}
}
let exp_ms = now_ms + 210_000; // 3.5 minutes
// Build JWT manually to include custom sign_type header
// Header: {"alg":"HS256","typ":"JWT","sign_type":"SIGN"}
let header_json = r#"{"alg":"HS256","typ":"JWT","sign_type":"SIGN"}"#;
let header_b64 = base64url_encode_str(header_json);
// Payload: {"api_key":"...","exp":...,"timestamp":...}
let payload_json = format!(
r#"{{"api_key":"{}","exp":{},"timestamp":{}}}"#,
self.api_key_id, exp_ms, now_ms
);
let payload_b64 = base64url_encode_str(&payload_json);
// Sign: HMAC-SHA256(header.payload, secret)
let signing_input = format!("{header_b64}.{payload_b64}");
let key = hmac::Key::new(hmac::HMAC_SHA256, self.api_key_secret.as_bytes());
let signature = hmac::sign(&key, signing_input.as_bytes());
let sig_b64 = base64url_encode_bytes(signature.as_ref());
let token = format!("{signing_input}.{sig_b64}");
// Cache for 3 minutes
if let Ok(mut cache) = self.token_cache.lock() {
*cache = Some((token.clone(), now_ms + 180_000));
}
Ok(token)
}
fn http_client(&self) -> Client {
crate::config::build_runtime_proxy_client_with_timeouts("provider.glm", 120, 10)
}
}
#[async_trait]
impl Provider for GlmProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let token = self.generate_token()?;
let mut messages = Vec::new();
if let Some(sys) = system_prompt {
messages.push(Message {
role: "system".to_string(),
content: sys.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
temperature,
};
let url = format!("{}/chat/completions", self.base_url);
let response = self
.http_client()
.post(&url)
.header("Authorization", format!("Bearer {token}"))
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("GLM API error: {error}");
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from GLM"))
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let token = self.generate_token()?;
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let request = ChatRequest {
model: model.to_string(),
messages: api_messages,
temperature,
};
let url = format!("{}/chat/completions", self.base_url);
let response = self
            .http_client()
.post(&url)
.header("Authorization", format!("Bearer {token}"))
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("GLM API error: {error}");
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from GLM"))
}
async fn warmup(&self) -> anyhow::Result<()> {
if self.api_key_id.is_empty() || self.api_key_secret.is_empty() {
return Ok(());
}
        // Generate and cache a JWT token up front.
let token = self.generate_token()?;
let url = format!("{}/chat/completions", self.base_url);
// GET will likely return 405 but establishes the TLS + HTTP/2 connection pool.
let _ = self
            .http_client()
.get(&url)
.header("Authorization", format!("Bearer {token}"))
.send()
.await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parses_api_key() {
let p = GlmProvider::new(Some("abc123.secretXYZ"));
assert_eq!(p.api_key_id, "abc123");
assert_eq!(p.api_key_secret, "secretXYZ");
}
#[test]
fn handles_no_key() {
let p = GlmProvider::new(None);
assert!(p.api_key_id.is_empty());
assert!(p.api_key_secret.is_empty());
}
#[test]
fn handles_invalid_key_format() {
let p = GlmProvider::new(Some("no-dot-here"));
assert!(p.api_key_id.is_empty());
assert!(p.api_key_secret.is_empty());
}
#[test]
fn generates_jwt_token() {
let p = GlmProvider::new(Some("testid.testsecret"));
let token = p.generate_token().unwrap();
assert!(!token.is_empty());
// JWT has 3 dot-separated parts
let parts: Vec<&str> = token.split('.').collect();
assert_eq!(parts.len(), 3, "JWT should have 3 parts: {token}");
}
#[test]
fn caches_token() {
let p = GlmProvider::new(Some("testid.testsecret"));
let token1 = p.generate_token().unwrap();
let token2 = p.generate_token().unwrap();
assert_eq!(token1, token2, "Cached token should be reused");
}
#[test]
fn fails_without_key() {
let p = GlmProvider::new(None);
let result = p.generate_token();
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("API key not set"));
}
#[tokio::test]
async fn chat_fails_without_key() {
let p = GlmProvider::new(None);
let result = p
.chat_with_system(None, "hello", "glm-4.7", 0.7)
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn chat_with_history_fails_without_key() {
let p = GlmProvider::new(None);
let messages = vec![
ChatMessage::system("You are helpful."),
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
ChatMessage::user("What did I say?"),
];
let result = p.chat_with_history(&messages, "glm-4.7", 0.7).await;
assert!(result.is_err());
}
#[test]
fn base64url_no_padding() {
let encoded = base64url_encode_bytes(b"hello");
assert!(!encoded.contains('='));
assert!(!encoded.contains('+'));
assert!(!encoded.contains('/'));
}
#[tokio::test]
async fn warmup_without_key_is_noop() {
let provider = GlmProvider::new(None);
let result = provider.warmup().await;
assert!(result.is_ok());
}
}

View File

@@ -0,0 +1,326 @@
//! KiloCLI subprocess provider.
//!
//! Integrates with the KiloCLI tool, spawning the `kilo` binary
//! as a subprocess for each inference request. This allows using KiloCLI's AI
//! models without an interactive UI session.
//!
//! # Usage
//!
//! The `kilo` binary must be available in `PATH`, or its location must be
//! set via the `KILO_CLI_PATH` environment variable.
//!
//! KiloCLI is invoked as:
//! ```text
//! kilo --print -
//! ```
//! with prompt content written to stdin.
//!
//! # Limitations
//!
//! - **Conversation history**: Only the system prompt (if present) and the last
//! user message are forwarded. Full multi-turn history is not preserved because
//! the CLI accepts a single prompt per invocation.
//! - **System prompt**: The system prompt is prepended to the user message with a
//! blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
//! - **Temperature**: The CLI does not expose a temperature parameter.
//! Only default values are accepted; custom values return an explicit error.
//!
//! # Authentication
//!
//! Authentication is handled by KiloCLI itself (its own credential store).
//! No explicit API key is required by this provider.
//!
//! # Environment variables
//!
//! - `KILO_CLI_PATH` — override the path to the `kilo` binary (default: `"kilo"`)
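//!
//! # Example
//!
//! A minimal sketch (ignored by doctests) of a one-shot request; the module
//! path here is an assumption based on this crate's provider layout:
//!
//! ```rust,ignore
//! use zeroclaw::providers::kilo_cli::KiloCliProvider;
//! use zeroclaw::providers::Provider;
//!
//! let provider = KiloCliProvider::new();
//! // "default" forwards no --model flag, so the CLI picks its own model.
//! let reply = provider
//!     .chat_with_system(Some("Be brief."), "What is Rust?", "default", 0.7)
//!     .await?;
//! println!("{reply}");
//! ```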
use crate::providers::traits::{ChatRequest, ChatResponse, Provider, TokenUsage};
use async_trait::async_trait;
use std::path::PathBuf;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::time::{timeout, Duration};
/// Environment variable for overriding the path to the `kilo` binary.
pub const KILO_CLI_PATH_ENV: &str = "KILO_CLI_PATH";
/// Default `kilo` binary name (resolved via `PATH`).
const DEFAULT_KILO_CLI_BINARY: &str = "kilo";
/// Model name used to signal "use the provider's own default model".
const DEFAULT_MODEL_MARKER: &str = "default";
/// KiloCLI requests are bounded to avoid hung subprocesses.
const KILO_CLI_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Avoid leaking oversized stderr payloads.
const MAX_KILO_CLI_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls; allow only baseline defaults.
const KILO_CLI_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
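/// Tolerance for floating-point comparison against the supported temperatures.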
const TEMP_EPSILON: f64 = 1e-9;
/// Provider that invokes the KiloCLI as a subprocess.
///
/// Each inference request spawns a fresh `kilo` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
pub struct KiloCliProvider {
/// Path to the `kilo` binary.
binary_path: PathBuf,
}
impl KiloCliProvider {
/// Create a new `KiloCliProvider`.
///
/// The binary path is resolved from `KILO_CLI_PATH` env var if set,
/// otherwise defaults to `"kilo"` (found via `PATH`).
pub fn new() -> Self {
let binary_path = std::env::var(KILO_CLI_PATH_ENV)
.ok()
.filter(|path| !path.trim().is_empty())
.map(PathBuf::from)
.unwrap_or_else(|| PathBuf::from(DEFAULT_KILO_CLI_BINARY));
Self { binary_path }
}
/// Returns true if the model argument should be forwarded to the CLI.
fn should_forward_model(model: &str) -> bool {
let trimmed = model.trim();
!trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
}
fn supports_temperature(temperature: f64) -> bool {
KILO_CLI_SUPPORTED_TEMPERATURES
.iter()
.any(|v| (temperature - v).abs() < TEMP_EPSILON)
}
fn validate_temperature(temperature: f64) -> anyhow::Result<()> {
if !temperature.is_finite() {
anyhow::bail!("KiloCLI provider received non-finite temperature value");
}
if !Self::supports_temperature(temperature) {
anyhow::bail!(
"temperature unsupported by KiloCLI: {temperature}. \
Supported values: 0.7 or 1.0"
);
}
Ok(())
}
fn redact_stderr(stderr: &[u8]) -> String {
let text = String::from_utf8_lossy(stderr);
let trimmed = text.trim();
if trimmed.is_empty() {
return String::new();
}
if trimmed.chars().count() <= MAX_KILO_CLI_STDERR_CHARS {
return trimmed.to_string();
}
let clipped: String = trimmed.chars().take(MAX_KILO_CLI_STDERR_CHARS).collect();
format!("{clipped}...")
}
/// Invoke the kilo binary with the given prompt and optional model.
/// Returns the trimmed stdout output as the assistant response.
async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
let mut cmd = Command::new(&self.binary_path);
cmd.arg("--print");
if Self::should_forward_model(model) {
cmd.arg("--model").arg(model);
}
// Read prompt from stdin to avoid exposing sensitive content in process args.
cmd.arg("-");
cmd.kill_on_drop(true);
cmd.stdin(std::process::Stdio::piped());
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
let mut child = cmd.spawn().map_err(|err| {
anyhow::anyhow!(
"Failed to spawn KiloCLI binary at {}: {err}. \
Ensure `kilo` is installed and in PATH, or set KILO_CLI_PATH.",
self.binary_path.display()
)
})?;
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(message.as_bytes())
.await
.map_err(|err| anyhow::anyhow!("Failed to write prompt to KiloCLI stdin: {err}"))?;
stdin
.shutdown()
.await
.map_err(|err| anyhow::anyhow!("Failed to finalize KiloCLI stdin stream: {err}"))?;
}
let output = timeout(KILO_CLI_REQUEST_TIMEOUT, child.wait_with_output())
.await
.map_err(|_| {
anyhow::anyhow!(
"KiloCLI request timed out after {:?} (binary: {})",
KILO_CLI_REQUEST_TIMEOUT,
self.binary_path.display()
)
})?
.map_err(|err| anyhow::anyhow!("KiloCLI process failed: {err}"))?;
if !output.status.success() {
let code = output.status.code().unwrap_or(-1);
let stderr_excerpt = Self::redact_stderr(&output.stderr);
let stderr_note = if stderr_excerpt.is_empty() {
String::new()
} else {
format!(" Stderr: {stderr_excerpt}")
};
anyhow::bail!(
"KiloCLI exited with non-zero status {code}. \
Check that KiloCLI is authenticated and the CLI is supported.{stderr_note}"
);
}
let text = String::from_utf8(output.stdout)
.map_err(|err| anyhow::anyhow!("KiloCLI produced non-UTF-8 output: {err}"))?;
Ok(text.trim().to_string())
}
}
impl Default for KiloCliProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Provider for KiloCliProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
Self::validate_temperature(temperature)?;
let full_message = match system_prompt {
Some(system) if !system.is_empty() => {
format!("{system}\n\n{message}")
}
_ => message.to_string(),
};
self.invoke_cli(&full_message, model).await
}
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
usage: Some(TokenUsage::default()),
reasoning_content: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Mutex, OnceLock};
fn env_lock() -> std::sync::MutexGuard<'static, ()> {
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
LOCK.get_or_init(|| Mutex::new(()))
.lock()
.expect("env lock poisoned")
}
#[test]
fn new_uses_env_override() {
let _guard = env_lock();
let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
std::env::set_var(KILO_CLI_PATH_ENV, "/usr/local/bin/kilo");
let provider = KiloCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/kilo"));
match orig {
Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
None => std::env::remove_var(KILO_CLI_PATH_ENV),
}
}
#[test]
fn new_defaults_to_kilo() {
let _guard = env_lock();
let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
std::env::remove_var(KILO_CLI_PATH_ENV);
let provider = KiloCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("kilo"));
if let Some(v) = orig {
std::env::set_var(KILO_CLI_PATH_ENV, v);
}
}
#[test]
fn new_ignores_blank_env_override() {
let _guard = env_lock();
let orig = std::env::var(KILO_CLI_PATH_ENV).ok();
std::env::set_var(KILO_CLI_PATH_ENV, " ");
let provider = KiloCliProvider::new();
assert_eq!(provider.binary_path, PathBuf::from("kilo"));
match orig {
Some(v) => std::env::set_var(KILO_CLI_PATH_ENV, v),
None => std::env::remove_var(KILO_CLI_PATH_ENV),
}
}
#[test]
fn should_forward_model_standard() {
assert!(KiloCliProvider::should_forward_model("some-model"));
assert!(KiloCliProvider::should_forward_model("gpt-4o"));
}
#[test]
fn should_not_forward_default_model() {
assert!(!KiloCliProvider::should_forward_model(DEFAULT_MODEL_MARKER));
assert!(!KiloCliProvider::should_forward_model(""));
assert!(!KiloCliProvider::should_forward_model(" "));
}
#[test]
fn validate_temperature_allows_defaults() {
assert!(KiloCliProvider::validate_temperature(0.7).is_ok());
assert!(KiloCliProvider::validate_temperature(1.0).is_ok());
}
#[test]
fn validate_temperature_rejects_custom_value() {
let err = KiloCliProvider::validate_temperature(0.2).unwrap_err();
assert!(err
.to_string()
.contains("temperature unsupported by KiloCLI"));
}
#[tokio::test]
async fn invoke_missing_binary_returns_error() {
let provider = KiloCliProvider {
binary_path: PathBuf::from("/nonexistent/path/to/kilo"),
};
let result = provider.invoke_cli("hello", "default").await;
assert!(result.is_err());
let msg = result.unwrap_err().to_string();
assert!(
msg.contains("Failed to spawn KiloCLI binary"),
"unexpected error message: {msg}"
);
}
}

3625
third_party/zeroclaw/src/providers/mod.rs vendored Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,391 @@
//! Telnyx AI inference provider.
//!
//! Telnyx provides AI inference through an OpenAI-compatible API at
//! <https://api.telnyx.com/v2/ai> with access to 53+ models including
//! GPT-4o, Claude, Llama, Mistral, and more.
//!
//! # Configuration
//!
//! Set the `TELNYX_API_KEY` environment variable or configure in `config.toml`:
//!
//! ```toml
//! default_provider = "telnyx"
//! default_model = "openai/gpt-4o"
//! ```
use crate::providers::traits::{ChatMessage, Provider};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
/// Telnyx AI inference provider.
///
/// Uses the OpenAI-compatible chat completions API at `/v2/ai/chat/completions`.
/// Supports 53+ models including OpenAI, Anthropic (via API), Meta Llama,
/// Mistral, and more.
///
/// # Example
///
/// ```rust,ignore
/// use zeroclaw::providers::telnyx::TelnyxProvider;
/// use zeroclaw::providers::Provider;
///
/// let provider = TelnyxProvider::new(Some("your-api-key"));
/// let response = provider
///     .chat_with_system(None, "Hello!", "openai/gpt-4o", 0.7)
///     .await?;
/// ```
pub struct TelnyxProvider {
/// Telnyx API key
api_key: Option<String>,
/// HTTP client for API requests
client: Client,
}
impl TelnyxProvider {
/// Telnyx AI API base URL
const BASE_URL: &'static str = "https://api.telnyx.com/v2/ai";
/// Create a new Telnyx AI provider.
///
/// The API key can be provided directly or will be resolved from:
/// 1. `TELNYX_API_KEY` environment variable
    /// 2. `ZEROCLAW_API_KEY` environment variable (fallback)
    /// 3. `API_KEY` environment variable (last resort)
pub fn new(api_key: Option<&str>) -> Self {
let resolved_key = resolve_telnyx_api_key(api_key);
Self {
api_key: resolved_key,
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_else(|_| Client::new()),
}
}
    /// Create a provider with a custom base URL (for testing or proxies).
    ///
    /// Note: the custom base URL is currently ignored and requests always use
    /// [`Self::BASE_URL`]; the parameter is accepted for forward compatibility.
    pub fn with_base_url(api_key: Option<&str>, _base_url: &str) -> Self {
        Self::new(api_key)
    }
/// List available models from Telnyx AI.
///
/// Returns a list of model IDs that can be used with the chat API.
pub async fn list_models(&self) -> anyhow::Result<Vec<String>> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!("Telnyx API key not set. Set TELNYX_API_KEY environment variable.")
})?;
let response = self
.client
.get(format!("{}/models", Self::BASE_URL))
.header("Authorization", format!("Bearer {}", api_key))
.send()
.await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("Failed to list Telnyx models: {}", error);
}
let models_response: ModelsResponse = response.json().await?;
Ok(models_response.data.into_iter().map(|m| m.id).collect())
}
/// Build the chat completions URL
fn chat_url(&self) -> String {
format!("{}/chat/completions", Self::BASE_URL)
}
}
/// Resolve Telnyx API key from parameter or environment.
fn resolve_telnyx_api_key(api_key: Option<&str>) -> Option<String> {
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) {
return Some(key.to_string());
}
// Try Telnyx-specific env var first
if let Ok(key) = std::env::var("TELNYX_API_KEY") {
let key = key.trim();
if !key.is_empty() {
return Some(key.to_string());
}
}
// Fall back to generic env vars
for env_var in ["ZEROCLAW_API_KEY", "API_KEY"] {
if let Ok(key) = std::env::var(env_var) {
let key = key.trim();
if !key.is_empty() {
return Some(key.to_string());
}
}
}
None
}
/// Response from the /models endpoint
#[derive(Debug, Deserialize)]
struct ModelsResponse {
data: Vec<ModelInfo>,
}
#[derive(Debug, Deserialize)]
struct ModelInfo {
id: String,
}
/// Request body for chat completions
#[derive(Debug, serde::Serialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
temperature: f64,
}
#[derive(Debug, serde::Serialize)]
struct Message {
role: String,
content: String,
}
/// Response from chat completions API
#[derive(Debug, Deserialize)]
struct ChatResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
content: String,
}
#[async_trait]
impl Provider for TelnyxProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"Telnyx API key not set. Set TELNYX_API_KEY environment variable or run `zeroclaw onboard`."
)
})?;
let mut messages = Vec::new();
if let Some(sys) = system_prompt {
messages.push(Message {
role: "system".to_string(),
content: sys.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
temperature,
};
let response = self
.client
.post(self.chat_url())
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error = response.text().await?;
let sanitized = super::sanitize_api_error(&error);
anyhow::bail!("Telnyx API error ({}): {}", status, sanitized);
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from Telnyx"))
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"Telnyx API key not set. Set TELNYX_API_KEY environment variable or run `zeroclaw onboard`."
)
})?;
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let request = ChatRequest {
model: model.to_string(),
messages: api_messages,
temperature,
};
let response = self
.client
.post(self.chat_url())
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
let error = response.text().await?;
let sanitized = super::sanitize_api_error(&error);
anyhow::bail!("Telnyx API error ({}): {}", status, sanitized);
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from Telnyx"))
}
async fn warmup(&self) -> anyhow::Result<()> {
// Pre-warm the connection pool
let _ = self
.client
.get(format!("{}/models", Self::BASE_URL))
.send()
.await;
Ok(())
}
}
/// Popular Telnyx AI models for easy reference.
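///
/// For example, pass a constant anywhere a model name is expected, e.g.
/// `provider.chat_with_system(None, "Hello", models::GPT_4O, 0.7)`.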
pub mod models {
/// OpenAI GPT-4o (recommended for most tasks)
pub const GPT_4O: &str = "openai/gpt-4o";
/// OpenAI GPT-4o Mini (fast and cost-effective)
pub const GPT_4O_MINI: &str = "openai/gpt-4o-mini";
/// OpenAI GPT-4 Turbo
pub const GPT_4_TURBO: &str = "openai/gpt-4-turbo";
/// Anthropic Claude 3.5 Sonnet (via Telnyx proxy)
pub const CLAUDE_3_5_SONNET: &str = "anthropic/claude-3.5-sonnet";
/// Meta Llama 3.1 70B Instruct
pub const LLAMA_3_1_70B: &str = "meta-llama/llama-3.1-70b-instruct";
/// Meta Llama 3.1 8B Instruct (fast)
pub const LLAMA_3_1_8B: &str = "meta-llama/llama-3.1-8b-instruct";
/// Mistral Large
pub const MISTRAL_LARGE: &str = "mistralai/mistral-large";
/// Mistral Small (fast)
pub const MISTRAL_SMALL: &str = "mistralai/mistral-small";
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn creates_provider_with_key() {
let provider = TelnyxProvider::new(Some("test-key"));
assert!(provider.api_key.is_some());
}
#[test]
fn creates_provider_without_key() {
let _provider = TelnyxProvider::new(None);
// Will be None if env vars not set
}
#[test]
fn model_constants_are_valid() {
assert!(models::GPT_4O.starts_with("openai/"));
assert!(models::CLAUDE_3_5_SONNET.starts_with("anthropic/"));
assert!(models::LLAMA_3_1_70B.starts_with("meta-llama/"));
assert!(models::MISTRAL_LARGE.starts_with("mistralai/"));
}
#[test]
fn resolve_key_from_parameter() {
let key = resolve_telnyx_api_key(Some("direct-key"));
assert_eq!(key, Some("direct-key".to_string()));
}
#[test]
fn resolve_key_trims_whitespace() {
let key = resolve_telnyx_api_key(Some(" spaced-key "));
assert_eq!(key, Some("spaced-key".to_string()));
}
#[test]
fn models_response_deserializes() {
let json = r#"{
"data": [
{"id": "openai/gpt-4o"},
{"id": "anthropic/claude-3.5-sonnet"}
]
}"#;
let response: ModelsResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.data.len(), 2);
assert_eq!(response.data[0].id, "openai/gpt-4o");
}
#[test]
fn chat_request_serializes() {
let req = ChatRequest {
model: "openai/gpt-4o".to_string(),
messages: vec![
Message {
role: "system".to_string(),
content: "You are helpful.".to_string(),
},
Message {
role: "user".to_string(),
content: "Hello".to_string(),
},
],
temperature: 0.7,
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("openai/gpt-4o"));
assert!(json.contains("system"));
assert!(json.contains("user"));
}
#[test]
fn chat_response_deserializes() {
let json = r#"{"choices":[{"message":{"content":"Hello from Telnyx!"}}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices[0].message.content, "Hello from Telnyx!");
}
}

File diff suppressed because it is too large Load Diff