feat: add deepseek provider scaffolding
This commit is contained in:
1171
Cargo.lock
generated
1171
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -6,6 +6,7 @@ edition = "2021"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
hex = "0.4"
|
hex = "0.4"
|
||||||
hmac = "0.12"
|
hmac = "0.12"
|
||||||
|
reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls"] }
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1"
|
serde_json = "1"
|
||||||
sha2 = "0.10"
|
sha2 = "0.10"
|
||||||
|
|||||||
3
src/config/mod.rs
Normal file
3
src/config/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
mod settings;
|
||||||
|
|
||||||
|
pub use settings::{ConfigError, DeepSeekSettings};
|
||||||
46
src/config/settings.rs
Normal file
46
src/config/settings.rs
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
const DEFAULT_DEEPSEEK_BASE_URL: &str = "https://api.deepseek.com";
|
||||||
|
const DEFAULT_DEEPSEEK_MODEL: &str = "deepseek-chat";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
pub struct DeepSeekSettings {
|
||||||
|
pub api_key: String,
|
||||||
|
pub base_url: String,
|
||||||
|
pub model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DeepSeekSettings {
|
||||||
|
pub fn from_env() -> Result<Self, ConfigError> {
|
||||||
|
let api_key = std::env::var("DEEPSEEK_API_KEY")
|
||||||
|
.map_err(|_| ConfigError::MissingEnv("DEEPSEEK_API_KEY"))?;
|
||||||
|
let base_url = std::env::var("DEEPSEEK_BASE_URL")
|
||||||
|
.unwrap_or_else(|_| DEFAULT_DEEPSEEK_BASE_URL.to_string());
|
||||||
|
let model =
|
||||||
|
std::env::var("DEEPSEEK_MODEL").unwrap_or_else(|_| DEFAULT_DEEPSEEK_MODEL.to_string());
|
||||||
|
|
||||||
|
if api_key.trim().is_empty() {
|
||||||
|
return Err(ConfigError::EmptyValue("DEEPSEEK_API_KEY"));
|
||||||
|
}
|
||||||
|
if base_url.trim().is_empty() {
|
||||||
|
return Err(ConfigError::EmptyValue("DEEPSEEK_BASE_URL"));
|
||||||
|
}
|
||||||
|
if model.trim().is_empty() {
|
||||||
|
return Err(ConfigError::EmptyValue("DEEPSEEK_MODEL"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
api_key,
|
||||||
|
base_url,
|
||||||
|
model,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error, Clone, PartialEq, Eq)]
|
||||||
|
pub enum ConfigError {
|
||||||
|
#[error("missing environment variable: {0}")]
|
||||||
|
MissingEnv(&'static str),
|
||||||
|
#[error("environment variable must not be empty: {0}")]
|
||||||
|
EmptyValue(&'static str),
|
||||||
|
}
|
||||||
@@ -1,4 +1,6 @@
|
|||||||
pub mod agent;
|
pub mod agent;
|
||||||
|
pub mod config;
|
||||||
|
pub mod llm;
|
||||||
pub mod pipe;
|
pub mod pipe;
|
||||||
pub mod security;
|
pub mod security;
|
||||||
|
|
||||||
|
|||||||
154
src/llm/deepseek.rs
Normal file
154
src/llm/deepseek.rs
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
use reqwest::blocking::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
use crate::config::DeepSeekSettings;
|
||||||
|
use crate::llm::provider::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DeepSeekProvider {
|
||||||
|
settings: DeepSeekSettings,
|
||||||
|
client: Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DeepSeekProvider {
|
||||||
|
pub fn from_env() -> Result<Self, LlmError> {
|
||||||
|
Ok(Self::new(DeepSeekSettings::from_env()?))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(settings: DeepSeekSettings) -> Self {
|
||||||
|
Self {
|
||||||
|
settings,
|
||||||
|
client: Client::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn settings(&self) -> &DeepSeekSettings {
|
||||||
|
&self.settings
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_chat_request(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
tools: &[ToolDefinition],
|
||||||
|
) -> DeepSeekChatRequest {
|
||||||
|
DeepSeekChatRequest {
|
||||||
|
model: self.settings.model.clone(),
|
||||||
|
messages: messages.to_vec(),
|
||||||
|
tools: if tools.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.map(|tool| DeepSeekToolDefinition {
|
||||||
|
tool_type: "function".to_string(),
|
||||||
|
function: DeepSeekFunctionDefinition {
|
||||||
|
name: tool.name,
|
||||||
|
description: tool.description,
|
||||||
|
parameters: tool.parameters,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
stream: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn completions_url(&self) -> String {
|
||||||
|
format!(
|
||||||
|
"{}/chat/completions",
|
||||||
|
self.settings.base_url.trim_end_matches('/')
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LlmProvider for DeepSeekProvider {
|
||||||
|
fn chat(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
tools: &[ToolDefinition],
|
||||||
|
) -> Result<Vec<ToolFunctionCall>, LlmError> {
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(self.completions_url())
|
||||||
|
.bearer_auth(&self.settings.api_key)
|
||||||
|
.json(&self.build_chat_request(messages, tools))
|
||||||
|
.send()?
|
||||||
|
.error_for_status()?
|
||||||
|
.json::<DeepSeekChatResponse>()?;
|
||||||
|
|
||||||
|
let tool_calls = response
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.and_then(|choice| choice.message.tool_calls)
|
||||||
|
.ok_or(LlmError::NoToolCalls)?;
|
||||||
|
|
||||||
|
tool_calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|call| {
|
||||||
|
let arguments = serde_json::from_str(&call.function.arguments)
|
||||||
|
.map_err(|err| LlmError::InvalidToolArguments(err.to_string()))?;
|
||||||
|
Ok(ToolFunctionCall {
|
||||||
|
id: call.id,
|
||||||
|
name: call.function.name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||||
|
pub struct DeepSeekChatRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub messages: Vec<ChatMessage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<DeepSeekToolDefinition>>,
|
||||||
|
pub stream: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||||
|
pub struct DeepSeekToolDefinition {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: String,
|
||||||
|
pub function: DeepSeekFunctionDefinition,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||||
|
pub struct DeepSeekFunctionDefinition {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub parameters: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct DeepSeekChatResponse {
|
||||||
|
choices: Vec<DeepSeekChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct DeepSeekChoice {
|
||||||
|
message: DeepSeekResponseMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct DeepSeekResponseMessage {
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Option<Vec<DeepSeekResponseToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct DeepSeekResponseToolCall {
|
||||||
|
id: String,
|
||||||
|
function: DeepSeekResponseFunctionCall,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct DeepSeekResponseFunctionCall {
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
5
src/llm/mod.rs
Normal file
5
src/llm/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
mod deepseek;
|
||||||
|
mod provider;
|
||||||
|
|
||||||
|
pub use deepseek::{DeepSeekChatRequest, DeepSeekProvider};
|
||||||
|
pub use provider::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
|
||||||
45
src/llm/provider.rs
Normal file
45
src/llm/provider.rs
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
pub struct ChatMessage {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||||
|
pub struct ToolDefinition {
|
||||||
|
pub name: String,
|
||||||
|
pub description: String,
|
||||||
|
pub parameters: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub struct ToolFunctionCall {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait LlmProvider {
|
||||||
|
fn chat(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
tools: &[ToolDefinition],
|
||||||
|
) -> Result<Vec<ToolFunctionCall>, LlmError>;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Error)]
|
||||||
|
pub enum LlmError {
|
||||||
|
#[error(transparent)]
|
||||||
|
Config(#[from] crate::config::ConfigError),
|
||||||
|
#[error(transparent)]
|
||||||
|
Http(#[from] reqwest::Error),
|
||||||
|
#[error(transparent)]
|
||||||
|
Json(#[from] serde_json::Error),
|
||||||
|
#[error("llm returned no tool calls")]
|
||||||
|
NoToolCalls,
|
||||||
|
#[error("llm returned malformed tool arguments: {0}")]
|
||||||
|
InvalidToolArguments(String),
|
||||||
|
}
|
||||||
67
tests/deepseek_provider_test.rs
Normal file
67
tests/deepseek_provider_test.rs
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
|
||||||
|
use serde_json::json;
|
||||||
|
use sgclaw::config::DeepSeekSettings;
|
||||||
|
use sgclaw::llm::{ChatMessage, DeepSeekProvider, ToolDefinition};
|
||||||
|
|
||||||
|
fn env_lock() -> &'static Mutex<()> {
|
||||||
|
static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
|
||||||
|
LOCK.get_or_init(|| Mutex::new(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn deepseek_settings_load_defaults_from_env() {
|
||||||
|
let _guard = env_lock().lock().unwrap();
|
||||||
|
std::env::set_var("DEEPSEEK_API_KEY", "test-key");
|
||||||
|
std::env::remove_var("DEEPSEEK_BASE_URL");
|
||||||
|
std::env::remove_var("DEEPSEEK_MODEL");
|
||||||
|
|
||||||
|
let settings = DeepSeekSettings::from_env().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(settings.api_key, "test-key");
|
||||||
|
assert_eq!(settings.base_url, "https://api.deepseek.com");
|
||||||
|
assert_eq!(settings.model, "deepseek-chat");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn deepseek_request_shape_matches_openai_compatible_chat_format() {
|
||||||
|
let provider = DeepSeekProvider::new(DeepSeekSettings {
|
||||||
|
api_key: "test-key".to_string(),
|
||||||
|
base_url: "https://api.deepseek.com".to_string(),
|
||||||
|
model: "deepseek-chat".to_string(),
|
||||||
|
});
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage {
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: "You are sgClaw.".to_string(),
|
||||||
|
},
|
||||||
|
ChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "打开百度搜索天气".to_string(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
let tools = vec![ToolDefinition {
|
||||||
|
name: "browser_action".to_string(),
|
||||||
|
description: "Execute browser actions".to_string(),
|
||||||
|
parameters: json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"action": { "type": "string" }
|
||||||
|
},
|
||||||
|
"required": ["action"]
|
||||||
|
}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let request = provider.build_chat_request(&messages, &tools);
|
||||||
|
let serialized = serde_json::to_value(&request).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(serialized["model"], "deepseek-chat");
|
||||||
|
assert_eq!(serialized["stream"], false);
|
||||||
|
assert_eq!(serialized["messages"][0]["role"], "system");
|
||||||
|
assert_eq!(serialized["messages"][1]["content"], "打开百度搜索天气");
|
||||||
|
assert_eq!(serialized["tools"][0]["type"], "function");
|
||||||
|
assert_eq!(
|
||||||
|
serialized["tools"][0]["function"]["name"],
|
||||||
|
"browser_action"
|
||||||
|
);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user