155 lines
4.1 KiB
Rust
155 lines
4.1 KiB
Rust
use reqwest::blocking::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
|
|
use crate::config::DeepSeekSettings;
|
|
use crate::llm::provider::{ChatMessage, LlmError, LlmProvider, ToolDefinition, ToolFunctionCall};
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct DeepSeekProvider {
|
|
settings: DeepSeekSettings,
|
|
client: Client,
|
|
}
|
|
|
|
impl DeepSeekProvider {
|
|
pub fn from_env() -> Result<Self, LlmError> {
|
|
Ok(Self::new(DeepSeekSettings::from_env()?))
|
|
}
|
|
|
|
pub fn new(settings: DeepSeekSettings) -> Self {
|
|
Self {
|
|
settings,
|
|
client: Client::new(),
|
|
}
|
|
}
|
|
|
|
pub fn settings(&self) -> &DeepSeekSettings {
|
|
&self.settings
|
|
}
|
|
|
|
pub fn build_chat_request(
|
|
&self,
|
|
messages: &[ChatMessage],
|
|
tools: &[ToolDefinition],
|
|
) -> DeepSeekChatRequest {
|
|
DeepSeekChatRequest {
|
|
model: self.settings.model.clone(),
|
|
messages: messages.to_vec(),
|
|
tools: if tools.is_empty() {
|
|
None
|
|
} else {
|
|
Some(
|
|
tools
|
|
.iter()
|
|
.cloned()
|
|
.map(|tool| DeepSeekToolDefinition {
|
|
tool_type: "function".to_string(),
|
|
function: DeepSeekFunctionDefinition {
|
|
name: tool.name,
|
|
description: tool.description,
|
|
parameters: tool.parameters,
|
|
},
|
|
})
|
|
.collect(),
|
|
)
|
|
},
|
|
stream: false,
|
|
}
|
|
}
|
|
|
|
fn completions_url(&self) -> String {
|
|
format!(
|
|
"{}/chat/completions",
|
|
self.settings.base_url.trim_end_matches('/')
|
|
)
|
|
}
|
|
}
|
|
|
|
impl LlmProvider for DeepSeekProvider {
|
|
fn chat(
|
|
&self,
|
|
messages: &[ChatMessage],
|
|
tools: &[ToolDefinition],
|
|
) -> Result<Vec<ToolFunctionCall>, LlmError> {
|
|
let response = self
|
|
.client
|
|
.post(self.completions_url())
|
|
.bearer_auth(&self.settings.api_key)
|
|
.json(&self.build_chat_request(messages, tools))
|
|
.send()?
|
|
.error_for_status()?
|
|
.json::<DeepSeekChatResponse>()?;
|
|
|
|
let tool_calls = response
|
|
.choices
|
|
.into_iter()
|
|
.next()
|
|
.and_then(|choice| choice.message.tool_calls)
|
|
.ok_or(LlmError::NoToolCalls)?;
|
|
|
|
tool_calls
|
|
.into_iter()
|
|
.map(|call| {
|
|
let arguments = serde_json::from_str(&call.function.arguments)
|
|
.map_err(|err| LlmError::InvalidToolArguments(err.to_string()))?;
|
|
Ok(ToolFunctionCall {
|
|
id: call.id,
|
|
name: call.function.name,
|
|
arguments,
|
|
})
|
|
})
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Serialize)]
|
|
pub struct DeepSeekChatRequest {
|
|
pub model: String,
|
|
pub messages: Vec<ChatMessage>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub tools: Option<Vec<DeepSeekToolDefinition>>,
|
|
pub stream: bool,
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Serialize)]
|
|
pub struct DeepSeekToolDefinition {
|
|
#[serde(rename = "type")]
|
|
pub tool_type: String,
|
|
pub function: DeepSeekFunctionDefinition,
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Serialize)]
|
|
pub struct DeepSeekFunctionDefinition {
|
|
pub name: String,
|
|
pub description: String,
|
|
pub parameters: Value,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct DeepSeekChatResponse {
|
|
choices: Vec<DeepSeekChoice>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct DeepSeekChoice {
|
|
message: DeepSeekResponseMessage,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct DeepSeekResponseMessage {
|
|
#[serde(default)]
|
|
tool_calls: Option<Vec<DeepSeekResponseToolCall>>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct DeepSeekResponseToolCall {
|
|
id: String,
|
|
function: DeepSeekResponseFunctionCall,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct DeepSeekResponseFunctionCall {
|
|
name: String,
|
|
arguments: String,
|
|
}
|