Files
skill-lib/src/llm/deepseek.rs
2026-03-25 04:24:59 +00:00

155 lines
4.1 KiB
Rust

use reqwest::blocking::Client;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::config::DeepSeekSettings;
use crate::llm::provider::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
/// LLM provider backed by the DeepSeek chat-completions HTTP API.
///
/// Holds the connection settings plus a reusable blocking `reqwest` client.
#[derive(Clone, Debug)]
pub struct DeepSeekProvider {
    /// Endpoint, credentials and model selection for outgoing requests.
    settings: DeepSeekSettings,
    /// Blocking HTTP client used for every request this provider sends.
    client: Client,
}
impl DeepSeekProvider {
    /// Builds a provider from settings loaded by [`DeepSeekSettings::from_env`].
    ///
    /// # Errors
    ///
    /// Propagates the settings-loading error, converted into [`LlmError`].
    pub fn from_env() -> Result<Self, LlmError> {
        let settings = DeepSeekSettings::from_env()?;
        Ok(Self::new(settings))
    }

    /// Wraps `settings` together with a freshly constructed blocking HTTP client.
    pub fn new(settings: DeepSeekSettings) -> Self {
        let client = Client::new();
        Self { settings, client }
    }

    /// Returns the settings this provider was built with.
    pub fn settings(&self) -> &DeepSeekSettings {
        &self.settings
    }

    /// Assembles a non-streaming (`stream: false`) chat-completions request body.
    ///
    /// Each tool is wrapped as a `"function"`-typed entry. When `tools` is empty
    /// the `tools` field is `None`, so the field is skipped during serialization.
    pub fn build_chat_request(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
    ) -> DeepSeekChatRequest {
        let wrapped_tools = (!tools.is_empty()).then(|| {
            tools
                .iter()
                .map(|tool| DeepSeekToolDefinition {
                    tool_type: "function".to_string(),
                    function: DeepSeekFunctionDefinition {
                        name: tool.name.clone(),
                        description: tool.description.clone(),
                        parameters: tool.parameters.clone(),
                    },
                })
                .collect::<Vec<_>>()
        });

        DeepSeekChatRequest {
            model: self.settings.model.clone(),
            messages: messages.to_vec(),
            tools: wrapped_tools,
            stream: false,
        }
    }

    /// Joins the configured base URL (stripped of trailing slashes) with the
    /// `/chat/completions` path, so `".../v1"` and `".../v1/"` yield the same URL.
    fn completions_url(&self) -> String {
        let base = self.settings.base_url.trim_end_matches('/');
        format!("{}/chat/completions", base)
    }
}
impl LlmProvider for DeepSeekProvider {
    /// Sends `messages` (and `tools`, if any) to the DeepSeek chat-completions
    /// endpoint and returns the tool calls from the first choice, with each call's
    /// JSON-encoded `arguments` string parsed.
    ///
    /// # Errors
    ///
    /// - Transport, non-2xx status, or response-decoding failures from `reqwest`
    ///   (converted into [`LlmError`] via `?`).
    /// - [`LlmError::NoToolCalls`] when the response has no choices, the first
    ///   choice has no `tool_calls`, or `tool_calls` is an empty array.
    /// - [`LlmError::InvalidToolArguments`] when a call's `arguments` string
    ///   cannot be deserialized as JSON.
    fn chat(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
    ) -> Result<Vec<ToolFunctionCall>, LlmError> {
        let response = self
            .client
            .post(self.completions_url())
            .bearer_auth(&self.settings.api_key)
            .json(&self.build_chat_request(messages, tools))
            .send()?
            .error_for_status()?
            .json::<DeepSeekChatResponse>()?;

        let tool_calls = response
            .choices
            .into_iter()
            .next()
            .and_then(|choice| choice.message.tool_calls)
            // An explicit empty `tool_calls` array carries no calls, so treat it
            // the same as a missing/null field instead of returning `Ok(vec![])`.
            .filter(|calls| !calls.is_empty())
            .ok_or(LlmError::NoToolCalls)?;

        tool_calls
            .into_iter()
            .map(|call| {
                // The API delivers arguments as a JSON-encoded string.
                let arguments = serde_json::from_str(&call.function.arguments)
                    .map_err(|err| LlmError::InvalidToolArguments(err.to_string()))?;
                Ok(ToolFunctionCall {
                    id: call.id,
                    name: call.function.name,
                    arguments,
                })
            })
            .collect()
    }
}
/// JSON body for a DeepSeek chat-completions request.
///
/// Fields serialize in declaration order.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct DeepSeekChatRequest {
    /// Model identifier (filled from the provider settings by `build_chat_request`).
    pub model: String,
    /// Conversation messages sent with the request.
    pub messages: Vec<ChatMessage>,
    /// Tools offered to the model; the key is omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<DeepSeekToolDefinition>>,
    /// Streaming flag; `build_chat_request` always sets this to `false`.
    pub stream: bool,
}

/// One entry of the request's `tools` array.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct DeepSeekToolDefinition {
    /// Serialized under the JSON key `"type"`; `build_chat_request` uses `"function"`.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function signature exposed to the model.
    pub function: DeepSeekFunctionDefinition,
}

/// Function signature advertised to the model inside a tool definition.
#[derive(Clone, Debug, PartialEq, Serialize)]
pub struct DeepSeekFunctionDefinition {
    /// Function name the model uses when calling the tool.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter description passed through as an arbitrary JSON value
    /// (presumably a JSON Schema object — confirm against `ToolDefinition`).
    pub parameters: Value,
}
/// Subset of the chat-completions response that `chat` reads; any other JSON
/// fields are ignored during deserialization.
#[derive(Deserialize, Debug)]
struct DeepSeekChatResponse {
    /// Candidate completions; only the first one is inspected.
    choices: Vec<DeepSeekChoice>,
}

/// A single candidate completion.
#[derive(Deserialize, Debug)]
struct DeepSeekChoice {
    /// The assistant message for this candidate.
    message: DeepSeekResponseMessage,
}

/// Assistant message; only the tool calls are extracted.
#[derive(Deserialize, Debug)]
struct DeepSeekResponseMessage {
    /// Tool calls requested by the model; `None` when absent or `null`.
    #[serde(default)]
    tool_calls: Option<Vec<DeepSeekResponseToolCall>>,
}

/// One tool invocation requested by the model.
#[derive(Deserialize, Debug)]
struct DeepSeekResponseToolCall {
    /// Identifier of this tool call, forwarded into `ToolFunctionCall::id`.
    id: String,
    /// The function name and its encoded arguments.
    function: DeepSeekResponseFunctionCall,
}

/// Function name plus its arguments exactly as returned by the API.
#[derive(Deserialize, Debug)]
struct DeepSeekResponseFunctionCall {
    /// Name of the function the model chose to call.
    name: String,
    /// JSON-encoded arguments string; parsed later by `chat`.
    arguments: String,
}