feat: add deepseek provider scaffolding
This commit is contained in:
3
src/config/mod.rs
Normal file
3
src/config/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
mod settings;
|
||||
|
||||
pub use settings::{ConfigError, DeepSeekSettings};
|
||||
46
src/config/settings.rs
Normal file
46
src/config/settings.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
use thiserror::Error;
|
||||
|
||||
const DEFAULT_DEEPSEEK_BASE_URL: &str = "https://api.deepseek.com";
|
||||
const DEFAULT_DEEPSEEK_MODEL: &str = "deepseek-chat";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct DeepSeekSettings {
|
||||
pub api_key: String,
|
||||
pub base_url: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
impl DeepSeekSettings {
|
||||
pub fn from_env() -> Result<Self, ConfigError> {
|
||||
let api_key = std::env::var("DEEPSEEK_API_KEY")
|
||||
.map_err(|_| ConfigError::MissingEnv("DEEPSEEK_API_KEY"))?;
|
||||
let base_url = std::env::var("DEEPSEEK_BASE_URL")
|
||||
.unwrap_or_else(|_| DEFAULT_DEEPSEEK_BASE_URL.to_string());
|
||||
let model =
|
||||
std::env::var("DEEPSEEK_MODEL").unwrap_or_else(|_| DEFAULT_DEEPSEEK_MODEL.to_string());
|
||||
|
||||
if api_key.trim().is_empty() {
|
||||
return Err(ConfigError::EmptyValue("DEEPSEEK_API_KEY"));
|
||||
}
|
||||
if base_url.trim().is_empty() {
|
||||
return Err(ConfigError::EmptyValue("DEEPSEEK_BASE_URL"));
|
||||
}
|
||||
if model.trim().is_empty() {
|
||||
return Err(ConfigError::EmptyValue("DEEPSEEK_MODEL"));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
api_key,
|
||||
base_url,
|
||||
model,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Error, Clone, PartialEq, Eq)]
|
||||
pub enum ConfigError {
|
||||
#[error("missing environment variable: {0}")]
|
||||
MissingEnv(&'static str),
|
||||
#[error("environment variable must not be empty: {0}")]
|
||||
EmptyValue(&'static str),
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
pub mod agent;
|
||||
pub mod config;
|
||||
pub mod llm;
|
||||
pub mod pipe;
|
||||
pub mod security;
|
||||
|
||||
|
||||
154
src/llm/deepseek.rs
Normal file
154
src/llm/deepseek.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
use reqwest::blocking::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::config::DeepSeekSettings;
|
||||
use crate::llm::provider::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeepSeekProvider {
|
||||
settings: DeepSeekSettings,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl DeepSeekProvider {
|
||||
pub fn from_env() -> Result<Self, LlmError> {
|
||||
Ok(Self::new(DeepSeekSettings::from_env()?))
|
||||
}
|
||||
|
||||
pub fn new(settings: DeepSeekSettings) -> Self {
|
||||
Self {
|
||||
settings,
|
||||
client: Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn settings(&self) -> &DeepSeekSettings {
|
||||
&self.settings
|
||||
}
|
||||
|
||||
pub fn build_chat_request(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> DeepSeekChatRequest {
|
||||
DeepSeekChatRequest {
|
||||
model: self.settings.model.clone(),
|
||||
messages: messages.to_vec(),
|
||||
tools: if tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
tools
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|tool| DeepSeekToolDefinition {
|
||||
tool_type: "function".to_string(),
|
||||
function: DeepSeekFunctionDefinition {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters,
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
},
|
||||
stream: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn completions_url(&self) -> String {
|
||||
format!(
|
||||
"{}/chat/completions",
|
||||
self.settings.base_url.trim_end_matches('/')
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl LlmProvider for DeepSeekProvider {
|
||||
fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> Result<Vec<ToolFunctionCall>, LlmError> {
|
||||
let response = self
|
||||
.client
|
||||
.post(self.completions_url())
|
||||
.bearer_auth(&self.settings.api_key)
|
||||
.json(&self.build_chat_request(messages, tools))
|
||||
.send()?
|
||||
.error_for_status()?
|
||||
.json::<DeepSeekChatResponse>()?;
|
||||
|
||||
let tool_calls = response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.and_then(|choice| choice.message.tool_calls)
|
||||
.ok_or(LlmError::NoToolCalls)?;
|
||||
|
||||
tool_calls
|
||||
.into_iter()
|
||||
.map(|call| {
|
||||
let arguments = serde_json::from_str(&call.function.arguments)
|
||||
.map_err(|err| LlmError::InvalidToolArguments(err.to_string()))?;
|
||||
Ok(ToolFunctionCall {
|
||||
id: call.id,
|
||||
name: call.function.name,
|
||||
arguments,
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub struct DeepSeekChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<DeepSeekToolDefinition>>,
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub struct DeepSeekToolDefinition {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
pub function: DeepSeekFunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize)]
|
||||
pub struct DeepSeekFunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekChatResponse {
|
||||
choices: Vec<DeepSeekChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekChoice {
|
||||
message: DeepSeekResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekResponseMessage {
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<DeepSeekResponseToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekResponseToolCall {
|
||||
id: String,
|
||||
function: DeepSeekResponseFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeepSeekResponseFunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
5
src/llm/mod.rs
Normal file
5
src/llm/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod deepseek;
|
||||
mod provider;
|
||||
|
||||
pub use deepseek::{DeepSeekChatRequest, DeepSeekProvider};
|
||||
pub use provider::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
|
||||
45
src/llm/provider.rs
Normal file
45
src/llm/provider.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ToolFunctionCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
}
|
||||
|
||||
pub trait LlmProvider {
|
||||
fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[ToolDefinition],
|
||||
) -> Result<Vec<ToolFunctionCall>, LlmError>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum LlmError {
|
||||
#[error(transparent)]
|
||||
Config(#[from] crate::config::ConfigError),
|
||||
#[error(transparent)]
|
||||
Http(#[from] reqwest::Error),
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error("llm returned no tool calls")]
|
||||
NoToolCalls,
|
||||
#[error("llm returned malformed tool arguments: {0}")]
|
||||
InvalidToolArguments(String),
|
||||
}
|
||||
Reference in New Issue
Block a user