feat: add deepseek provider scaffolding
This commit is contained in:
1171
Cargo.lock
generated
1171
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ edition = "2021"
|
||||
[dependencies]
|
||||
hex = "0.4"
|
||||
hmac = "0.12"
|
||||
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
sha2 = "0.10"
|
||||
|
||||
3
src/config/mod.rs
Normal file
3
src/config/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod settings;
|
||||
|
||||
pub use settings::{ConfigError, DeepSeekSettings};
|
||||
46
src/config/settings.rs
Normal file
46
src/config/settings.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
use thiserror::Error;
|
||||
|
||||
const DEFAULT_DEEPSEEK_BASE_URL: &str = "https://api.deepseek.com";
|
||||
const DEFAULT_DEEPSEEK_MODEL: &str = "deepseek-chat";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct DeepSeekSettings {
|
||||
pub api_key: String,
|
||||
pub base_url: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl DeepSeekSettings {
|
||||
pub fn from_env() -> Result<Self, ConfigError> {
|
||||
let api_key = std::env::var("DEEPSEEK_API_KEY")
|
||||
.map_err(|_| ConfigError::MissingEnv("DEEPSEEK_API_KEY"))?;
|
||||
let base_url = std::env::var("DEEPSEEK_BASE_URL")
|
||||
.unwrap_or_else(|_| DEFAULT_DEEPSEEK_BASE_URL.to_string());
|
||||
let model =
|
||||
std::env::var("DEEPSEEK_MODEL").unwrap_or_else(|_| DEFAULT_DEEPSEEK_MODEL.to_string());
|
||||
|
||||
if api_key.trim().is_empty() {
|
||||
return Err(ConfigError::EmptyValue("DEEPSEEK_API_KEY"));
|
||||
}
|
||||
if base_url.trim().is_empty() {
|
||||
return Err(ConfigError::EmptyValue("DEEPSEEK_BASE_URL"));
|
||||
}
|
||||
if model.trim().is_empty() {
|
||||
return Err(ConfigError::EmptyValue("DEEPSEEK_MODEL"));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
api_key,
|
||||
base_url,
|
||||
model,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, Clone, PartialEq, Eq)]
|
||||
pub enum ConfigError {
|
||||
#[error("missing environment variable: {0}")]
|
||||
MissingEnv(&'static str),
|
||||
#[error("environment variable must not be empty: {0}")]
|
||||
EmptyValue(&'static str),
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
pub mod agent;
|
||||
pub mod config;
|
||||
pub mod llm;
|
||||
pub mod pipe;
|
||||
pub mod security;
|
||||
|
||||
|
||||
154
src/llm/deepseek.rs
Normal file
154
src/llm/deepseek.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
use reqwest::blocking::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::config::DeepSeekSettings;
|
||||
use crate::llm::provider::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeepSeekProvider {
|
||||
settings: DeepSeekSettings,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl DeepSeekProvider {
|
||||
pub fn from_env() -> Result<Self, LlmError> {
|
||||
Ok(Self::new(DeepSeekSettings::from_env()?))
|
||||
}
|
||||
|
||||
pub fn new(settings: DeepSeekSettings) -> Self {
|
||||
Self {
|
||||
settings,
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn settings(&self) -> &DeepSeekSettings {
|
||||
&self.settings
|
||||
}
|
||||
|
||||
pub fn build_chat_request(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> DeepSeekChatRequest {
|
||||
DeepSeekChatRequest {
|
||||
model: self.settings.model.clone(),
|
||||
messages: messages.to_vec(),
|
||||
tools: if tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
tools
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|tool| DeepSeekToolDefinition {
|
||||
tool_type: "function".to_string(),
|
||||
function: DeepSeekFunctionDefinition {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
},
|
||||
stream: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn completions_url(&self) -> String {
|
||||
format!(
|
||||
"{}/chat/completions",
|
||||
self.settings.base_url.trim_end_matches('/')
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl LlmProvider for DeepSeekProvider {
|
||||
fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> Result<Vec<ToolFunctionCall>, LlmError> {
|
||||
let response = self
|
||||
.client
|
||||
.post(self.completions_url())
|
||||
.bearer_auth(&self.settings.api_key)
|
||||
.json(&self.build_chat_request(messages, tools))
|
||||
.send()?
|
||||
.error_for_status()?
|
||||
.json::<DeepSeekChatResponse>()?;
|
||||
|
||||
let tool_calls = response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.and_then(|choice| choice.message.tool_calls)
|
||||
.ok_or(LlmError::NoToolCalls)?;
|
||||
|
||||
tool_calls
|
||||
.into_iter()
|
||||
.map(|call| {
|
||||
let arguments = serde_json::from_str(&call.function.arguments)
|
||||
.map_err(|err| LlmError::InvalidToolArguments(err.to_string()))?;
|
||||
Ok(ToolFunctionCall {
|
||||
id: call.id,
|
||||
name: call.function.name,
|
||||
arguments,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub struct DeepSeekChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<DeepSeekToolDefinition>>,
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub struct DeepSeekToolDefinition {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
pub function: DeepSeekFunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub struct DeepSeekFunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekChatResponse {
|
||||
choices: Vec<DeepSeekChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekChoice {
|
||||
message: DeepSeekResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekResponseMessage {
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<DeepSeekResponseToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekResponseToolCall {
|
||||
id: String,
|
||||
function: DeepSeekResponseFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekResponseFunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
5
src/llm/mod.rs
Normal file
5
src/llm/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod deepseek;
|
||||
mod provider;
|
||||
|
||||
pub use deepseek::{DeepSeekChatRequest, DeepSeekProvider};
|
||||
pub use provider::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
|
||||
45
src/llm/provider.rs
Normal file
45
src/llm/provider.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
}
|
||||
|
||||
pub trait LlmProvider {
|
||||
fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> Result<Vec<ToolFunctionCall>, LlmError>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LlmError {
|
||||
#[error(transparent)]
|
||||
Config(#[from] crate::config::ConfigError),
|
||||
#[error(transparent)]
|
||||
Http(#[from] reqwest::Error),
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error("llm returned no tool calls")]
|
||||
NoToolCalls,
|
||||
#[error("llm returned malformed tool arguments: {0}")]
|
||||
InvalidToolArguments(String),
|
||||
}
|
||||
67
tests/deepseek_provider_test.rs
Normal file
67
tests/deepseek_provider_test.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
|
||||
use serde_json::json;
|
||||
use sgclaw::config::DeepSeekSettings;
|
||||
use sgclaw::llm::{ChatMessage, DeepSeekProvider, ToolDefinition};
|
||||
|
||||
fn env_lock() -> &'static Mutex<()> {
|
||||
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
LOCK.get_or_init(|| Mutex::new(()))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deepseek_settings_load_defaults_from_env() {
|
||||
let _guard = env_lock().lock().unwrap();
|
||||
std::env::set_var("DEEPSEEK_API_KEY", "test-key");
|
||||
std::env::remove_var("DEEPSEEK_BASE_URL");
|
||||
std::env::remove_var("DEEPSEEK_MODEL");
|
||||
|
||||
let settings = DeepSeekSettings::from_env().unwrap();
|
||||
|
||||
assert_eq!(settings.api_key, "test-key");
|
||||
assert_eq!(settings.base_url, "https://api.deepseek.com");
|
||||
assert_eq!(settings.model, "deepseek-chat");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn deepseek_request_shape_matches_openai_compatible_chat_format() {
|
||||
let provider = DeepSeekProvider::new(DeepSeekSettings {
|
||||
api_key: "test-key".to_string(),
|
||||
base_url: "https://api.deepseek.com".to_string(),
|
||||
model: "deepseek-chat".to_string(),
|
||||
});
|
||||
let messages = vec![
|
||||
ChatMessage {
|
||||
role: "system".to_string(),
|
||||
content: "You are sgClaw.".to_string(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".to_string(),
|
||||
content: "打开百度搜索天气".to_string(),
|
||||
},
|
||||
];
|
||||
let tools = vec![ToolDefinition {
|
||||
name: "browser_action".to_string(),
|
||||
description: "Execute browser actions".to_string(),
|
||||
parameters: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": { "type": "string" }
|
||||
},
|
||||
"required": ["action"]
|
||||
}),
|
||||
}];
|
||||
|
||||
let request = provider.build_chat_request(&messages, &tools);
|
||||
let serialized = serde_json::to_value(&request).unwrap();
|
||||
|
||||
assert_eq!(serialized["model"], "deepseek-chat");
|
||||
assert_eq!(serialized["stream"], false);
|
||||
assert_eq!(serialized["messages"][0]["role"], "system");
|
||||
assert_eq!(serialized["messages"][1]["content"], "打开百度搜索天气");
|
||||
assert_eq!(serialized["tools"][0]["type"], "function");
|
||||
assert_eq!(
|
||||
serialized["tools"][0]["function"]["name"],
|
||||
"browser_action"
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user