feat: refactor sgclaw around zeroclaw compat runtime
This commit is contained in:
2
third_party/zeroclaw/src/commands/mod.rs
vendored
Normal file
2
third_party/zeroclaw/src/commands/mod.rs
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod self_test;
|
||||
pub mod update;
|
||||
281
third_party/zeroclaw/src/commands/self_test.rs
vendored
Normal file
281
third_party/zeroclaw/src/commands/self_test.rs
vendored
Normal file
@@ -0,0 +1,281 @@
|
||||
//! `zeroclaw self-test` — quick and full diagnostic checks.
|
||||
|
||||
use anyhow::Result;
|
||||
use std::path::Path;
|
||||
|
||||
/// Result of a single diagnostic check.
pub struct CheckResult {
    /// Short, stable identifier for the check (e.g. "config", "sqlite").
    pub name: &'static str,
    /// Whether the check succeeded.
    pub passed: bool,
    /// Human-readable explanation displayed next to the pass/fail icon.
    pub detail: String,
}
|
||||
|
||||
impl CheckResult {
|
||||
fn pass(name: &'static str, detail: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name,
|
||||
passed: true,
|
||||
detail: detail.into(),
|
||||
}
|
||||
}
|
||||
fn fail(name: &'static str, detail: impl Into<String>) -> Self {
|
||||
Self {
|
||||
name,
|
||||
passed: false,
|
||||
detail: detail.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the quick self-test suite (no network required).
///
/// Executes eight local checks in a fixed order and returns one
/// [`CheckResult`] per check. A failed check does NOT produce `Err` —
/// failures are reported inside the returned results.
pub async fn run_quick(config: &crate::config::Config) -> Result<Vec<CheckResult>> {
    let mut results = Vec::new();

    // 1. Config file exists and parses
    results.push(check_config(config));

    // 2. Workspace directory is writable
    results.push(check_workspace(&config.workspace_dir).await);

    // 3. SQLite memory backend opens
    results.push(check_sqlite(&config.workspace_dir));

    // 4. Provider registry has entries
    results.push(check_provider_registry());

    // 5. Tool registry has entries
    results.push(check_tool_registry(config));

    // 6. Channel registry loads
    results.push(check_channel_config(config));

    // 7. Security policy parses
    results.push(check_security_policy(config));

    // 8. Version sanity
    results.push(check_version());

    Ok(results)
}
|
||||
|
||||
/// Run the full self-test suite (includes network checks).
///
/// Runs everything [`run_quick`] does, then appends three checks that
/// need a reachable local gateway / memory backend. `Err` is only
/// returned if the quick suite itself errors.
pub async fn run_full(config: &crate::config::Config) -> Result<Vec<CheckResult>> {
    let mut results = run_quick(config).await?;

    // 9. Gateway health endpoint
    results.push(check_gateway_health(config).await);

    // 10. Memory write/read round-trip
    results.push(check_memory_roundtrip(config).await);

    // 11. WebSocket handshake
    results.push(check_websocket_handshake(config).await);

    Ok(results)
}
|
||||
|
||||
/// Print results in a formatted table.
|
||||
pub fn print_results(results: &[CheckResult]) {
|
||||
let total = results.len();
|
||||
let passed = results.iter().filter(|r| r.passed).count();
|
||||
let failed = total - passed;
|
||||
|
||||
println!();
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
let icon = if r.passed {
|
||||
"\x1b[32m✓\x1b[0m"
|
||||
} else {
|
||||
"\x1b[31m✗\x1b[0m"
|
||||
};
|
||||
println!(" {} {}/{} {} — {}", icon, i + 1, total, r.name, r.detail);
|
||||
}
|
||||
println!();
|
||||
if failed == 0 {
|
||||
println!(" \x1b[32mAll {total} checks passed.\x1b[0m");
|
||||
} else {
|
||||
println!(" \x1b[31m{failed}/{total} checks failed.\x1b[0m");
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
/// Check that the config file exists on disk.
///
/// NOTE(review): only existence is tested here — `config` was already
/// parsed by the caller, so a missing file means defaults are in use,
/// not a hard error.
fn check_config(config: &crate::config::Config) -> CheckResult {
    if config.config_path.exists() {
        CheckResult::pass(
            "config",
            format!("loaded from {}", config.config_path.display()),
        )
    } else {
        CheckResult::fail("config", "config file not found (using defaults)")
    }
}
|
||||
|
||||
/// Verify the workspace directory exists, is a directory, and is writable.
///
/// Writability is probed by actually writing (and removing) a small file,
/// since filesystem metadata alone cannot prove write access.
async fn check_workspace(workspace_dir: &Path) -> CheckResult {
    match tokio::fs::metadata(workspace_dir).await {
        Ok(meta) if meta.is_dir() => {
            // Try writing a temp file
            let test_file = workspace_dir.join(".selftest_probe");
            match tokio::fs::write(&test_file, b"ok").await {
                Ok(()) => {
                    // Best-effort cleanup; a failed removal is not a check failure.
                    let _ = tokio::fs::remove_file(&test_file).await;
                    CheckResult::pass(
                        "workspace",
                        format!("{} (writable)", workspace_dir.display()),
                    )
                }
                Err(e) => CheckResult::fail(
                    "workspace",
                    format!("{} (not writable: {e})", workspace_dir.display()),
                ),
            }
        }
        // Path exists but is a file/symlink/etc.
        Ok(_) => CheckResult::fail(
            "workspace",
            format!("{} exists but is not a directory", workspace_dir.display()),
        ),
        // Path missing or metadata unreadable.
        Err(e) => CheckResult::fail(
            "workspace",
            format!("{} (error: {e})", workspace_dir.display()),
        ),
    }
}
|
||||
|
||||
fn check_sqlite(workspace_dir: &Path) -> CheckResult {
|
||||
let db_path = workspace_dir.join("memory.db");
|
||||
match rusqlite::Connection::open(&db_path) {
|
||||
Ok(conn) => match conn.execute_batch("SELECT 1") {
|
||||
Ok(()) => CheckResult::pass("sqlite", "memory.db opens and responds"),
|
||||
Err(e) => CheckResult::fail("sqlite", format!("query failed: {e}")),
|
||||
},
|
||||
Err(e) => CheckResult::fail("sqlite", format!("cannot open memory.db: {e}")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify the provider registry is non-empty.
fn check_provider_registry() -> CheckResult {
    let providers = crate::providers::list_providers();
    if providers.is_empty() {
        CheckResult::fail("providers", "no providers registered")
    } else {
        CheckResult::pass(
            "providers",
            format!("{} providers available", providers.len()),
        )
    }
}
|
||||
|
||||
/// Verify the default tool set can be constructed under the configured
/// security policy and is non-empty.
fn check_tool_registry(config: &crate::config::Config) -> CheckResult {
    // Tools hold a shared handle to the security policy, hence the Arc.
    let security = std::sync::Arc::new(crate::security::SecurityPolicy::from_config(
        &config.autonomy,
        &config.workspace_dir,
    ));
    let tools = crate::tools::default_tools(security);
    if tools.is_empty() {
        CheckResult::fail("tools", "no tools registered")
    } else {
        CheckResult::pass("tools", format!("{} core tools available", tools.len()))
    }
}
|
||||
|
||||
/// Summarize the channel registry: total channel types vs. how many are
/// configured. Informational — this check always passes.
fn check_channel_config(config: &crate::config::Config) -> CheckResult {
    let channels = config.channels_config.channels();
    // Each entry appears to be (name, configured: bool) — count enabled ones.
    let configured = channels.iter().filter(|(_, c)| *c).count();
    CheckResult::pass(
        "channels",
        format!(
            "{} channel types, {} configured",
            channels.len(),
            configured
        ),
    )
}
|
||||
|
||||
/// Verify the security policy can be built from config.
///
/// Construction itself is the check — the policy value is discarded.
fn check_security_policy(config: &crate::config::Config) -> CheckResult {
    let _policy =
        crate::security::SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
    CheckResult::pass(
        "security",
        format!("autonomy level: {:?}", config.autonomy.level),
    )
}
|
||||
|
||||
fn check_version() -> CheckResult {
|
||||
let version = env!("CARGO_PKG_VERSION");
|
||||
CheckResult::pass("version", format!("v{version}"))
|
||||
}
|
||||
|
||||
/// Probe the local gateway's HTTP `/health` endpoint.
async fn check_gateway_health(config: &crate::config::Config) -> CheckResult {
    let port = config.gateway.port;
    // Wildcard bind addresses are not connectable as-is; probe via loopback.
    let host = if config.gateway.host == "[::]" || config.gateway.host == "0.0.0.0" {
        "127.0.0.1"
    } else {
        &config.gateway.host
    };
    let url = format!("http://{host}:{port}/health");
    match reqwest::Client::new()
        .get(&url)
        // Per-request timeout so an unresponsive gateway can't hang the suite.
        .timeout(std::time::Duration::from_secs(5))
        .send()
        .await
    {
        Ok(resp) if resp.status().is_success() => {
            CheckResult::pass("gateway", format!("health OK at {url}"))
        }
        Ok(resp) => CheckResult::fail("gateway", format!("health returned {}", resp.status())),
        Err(e) => CheckResult::fail("gateway", format!("not reachable at {url}: {e}")),
    }
}
|
||||
|
||||
/// Exercise the configured memory backend with a full
/// write → read → delete round-trip.
async fn check_memory_roundtrip(config: &crate::config::Config) -> CheckResult {
    let mem = match crate::memory::create_memory(
        &config.memory,
        &config.workspace_dir,
        config.api_key.as_deref(),
    ) {
        Ok(m) => m,
        Err(e) => return CheckResult::fail("memory", format!("cannot create backend: {e}")),
    };

    // Sentinel key/value, unlikely to collide with real user memories.
    let test_key = "__selftest_probe__";
    let test_value = "selftest_ok";

    if let Err(e) = mem
        .store(
            test_key,
            test_value,
            crate::memory::MemoryCategory::Core,
            None,
        )
        .await
    {
        return CheckResult::fail("memory", format!("write failed: {e}"));
    }

    // Read back; `forget` on every path so the probe entry never lingers
    // (cleanup is best-effort — its result is deliberately ignored).
    // NOTE(review): recall args presumed (query, limit=1, ...) — confirm signature.
    match mem.recall(test_key, 1, None, None, None).await {
        Ok(entries) if !entries.is_empty() => {
            let _ = mem.forget(test_key).await;
            CheckResult::pass("memory", "write/read/delete round-trip OK")
        }
        Ok(_) => {
            let _ = mem.forget(test_key).await;
            CheckResult::fail("memory", "no entries returned after round-trip")
        }
        Err(e) => {
            let _ = mem.forget(test_key).await;
            CheckResult::fail("memory", format!("read failed: {e}"))
        }
    }
}
|
||||
|
||||
/// Attempt a WebSocket handshake against the gateway's chat endpoint.
async fn check_websocket_handshake(config: &crate::config::Config) -> CheckResult {
    let port = config.gateway.port;
    // Same loopback substitution as the HTTP health check: wildcard bind
    // addresses cannot be dialed directly.
    let host = if config.gateway.host == "[::]" || config.gateway.host == "0.0.0.0" {
        "127.0.0.1"
    } else {
        &config.gateway.host
    };
    let url = format!("ws://{host}:{port}/ws/chat");

    // Success = handshake completes; the stream/response are discarded.
    match tokio_tungstenite::connect_async(&url).await {
        Ok((_, _)) => CheckResult::pass("websocket", format!("handshake OK at {url}")),
        Err(e) => CheckResult::fail("websocket", format!("handshake failed at {url}: {e}")),
    }
}
|
||||
599
third_party/zeroclaw/src/commands/update.rs
vendored
Normal file
599
third_party/zeroclaw/src/commands/update.rs
vendored
Normal file
@@ -0,0 +1,599 @@
|
||||
//! `zeroclaw update` — self-update pipeline with rollback.
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use std::path::Path;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// GitHub API endpoint for the most recent zeroclaw release.
const GITHUB_RELEASES_LATEST_URL: &str =
    "https://api.github.com/repos/zeroclaw-labs/zeroclaw/releases/latest";
/// Base GitHub API endpoint for fetching a release by tag (append "/<tag>").
const GITHUB_RELEASES_TAG_URL: &str =
    "https://api.github.com/repos/zeroclaw-labs/zeroclaw/releases/tags";
|
||||
|
||||
/// Outcome of an update check against the GitHub releases API.
#[derive(Debug)]
pub struct UpdateInfo {
    /// Version of the running binary (from `CARGO_PKG_VERSION`).
    pub current_version: String,
    /// Version of the release found on GitHub (leading 'v' stripped).
    pub latest_version: String,
    /// Download URL of the asset matching this platform, if any.
    pub download_url: Option<String>,
    /// True when `latest_version` is strictly newer than `current_version`.
    pub is_newer: bool,
}
|
||||
|
||||
/// Check for available updates without downloading.
|
||||
///
|
||||
/// If `target_version` is `Some`, fetch that specific release tag instead of latest.
|
||||
pub async fn check(target_version: Option<&str>) -> Result<UpdateInfo> {
|
||||
let current = env!("CARGO_PKG_VERSION").to_string();
|
||||
|
||||
let client = reqwest::Client::builder()
|
||||
.user_agent(format!("zeroclaw/{current}"))
|
||||
.timeout(std::time::Duration::from_secs(15))
|
||||
.build()?;
|
||||
|
||||
let url = match target_version {
|
||||
Some(v) => {
|
||||
let tag = if v.starts_with('v') {
|
||||
v.to_string()
|
||||
} else {
|
||||
format!("v{v}")
|
||||
};
|
||||
format!("{GITHUB_RELEASES_TAG_URL}/{tag}")
|
||||
}
|
||||
None => GITHUB_RELEASES_LATEST_URL.to_string(),
|
||||
};
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.context("failed to reach GitHub releases API")?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
bail!("GitHub API returned {}", resp.status());
|
||||
}
|
||||
|
||||
let release: serde_json::Value = resp.json().await?;
|
||||
let tag = release["tag_name"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown")
|
||||
.trim_start_matches('v')
|
||||
.to_string();
|
||||
|
||||
let download_url = find_asset_url(&release);
|
||||
let is_newer = version_is_newer(¤t, &tag);
|
||||
|
||||
Ok(UpdateInfo {
|
||||
current_version: current,
|
||||
latest_version: tag,
|
||||
download_url,
|
||||
is_newer,
|
||||
})
|
||||
}
|
||||
|
||||
/// Run the full 6-phase update pipeline.
|
||||
///
|
||||
/// If `target_version` is `Some`, fetch that specific version instead of latest.
|
||||
pub async fn run(target_version: Option<&str>) -> Result<()> {
|
||||
// Phase 1: Preflight
|
||||
info!("Phase 1/6: Preflight checks...");
|
||||
let update_info = check(target_version).await?;
|
||||
|
||||
if !update_info.is_newer {
|
||||
println!("Already up to date (v{}).", update_info.current_version);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!(
|
||||
"Update available: v{} -> v{}",
|
||||
update_info.current_version, update_info.latest_version
|
||||
);
|
||||
|
||||
let download_url = update_info
|
||||
.download_url
|
||||
.context("no suitable binary found for this platform")?;
|
||||
|
||||
let current_exe =
|
||||
std::env::current_exe().context("cannot determine current executable path")?;
|
||||
|
||||
// Phase 2: Download
|
||||
info!("Phase 2/6: Downloading...");
|
||||
let temp_dir = tempfile::tempdir().context("failed to create temp dir")?;
|
||||
let download_path = temp_dir.path().join("zeroclaw_new");
|
||||
download_binary(&download_url, &download_path).await?;
|
||||
|
||||
// Phase 3: Backup
|
||||
info!("Phase 3/6: Creating backup...");
|
||||
let backup_path = current_exe.with_extension("bak");
|
||||
tokio::fs::copy(¤t_exe, &backup_path)
|
||||
.await
|
||||
.context("failed to backup current binary")?;
|
||||
|
||||
// Phase 4: Validate
|
||||
info!("Phase 4/6: Validating download...");
|
||||
validate_binary(&download_path).await?;
|
||||
|
||||
// Phase 5: Swap
|
||||
info!("Phase 5/6: Swapping binary...");
|
||||
if let Err(e) = swap_binary(&download_path, ¤t_exe).await {
|
||||
// Rollback
|
||||
warn!("Swap failed, rolling back: {e}");
|
||||
if let Err(rollback_err) = rollback_binary(&backup_path, ¤t_exe).await {
|
||||
eprintln!("CRITICAL: Rollback also failed: {rollback_err}");
|
||||
eprintln!(
|
||||
"Manual recovery: cp {} {}",
|
||||
backup_path.display(),
|
||||
current_exe.display()
|
||||
);
|
||||
}
|
||||
bail!("Update failed during swap: {e}");
|
||||
}
|
||||
|
||||
// Phase 6: Smoke test
|
||||
info!("Phase 6/6: Smoke test...");
|
||||
match smoke_test(¤t_exe).await {
|
||||
Ok(()) => {
|
||||
// Cleanup backup on success
|
||||
let _ = tokio::fs::remove_file(&backup_path).await;
|
||||
println!("Successfully updated to v{}!", update_info.latest_version);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Smoke test failed, rolling back: {e}");
|
||||
rollback_binary(&backup_path, ¤t_exe)
|
||||
.await
|
||||
.context("rollback after smoke test failure")?;
|
||||
bail!("Update rolled back — smoke test failed: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn find_asset_url(release: &serde_json::Value) -> Option<String> {
|
||||
let target = current_target_triple();
|
||||
|
||||
release["assets"]
|
||||
.as_array()?
|
||||
.iter()
|
||||
.find(|asset| {
|
||||
asset["name"]
|
||||
.as_str()
|
||||
.map(|name| name.contains(target))
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.and_then(|asset| asset["browser_download_url"].as_str().map(String::from))
|
||||
}
|
||||
|
||||
/// Return the exact Rust target triple for the current platform.
///
/// Using full triples (e.g. `aarch64-unknown-linux-gnu` instead of the
/// shorter `aarch64-unknown-linux`) prevents substring matches from
/// selecting the wrong asset (e.g. an Android binary on a GNU/Linux host).
fn current_target_triple() -> &'static str {
    let is_macos = cfg!(target_os = "macos");
    let is_linux = cfg!(target_os = "linux");
    let is_arm64 = cfg!(target_arch = "aarch64");

    match (is_macos, is_linux, is_arm64) {
        (true, _, true) => "aarch64-apple-darwin",
        (true, _, false) => "x86_64-apple-darwin",
        (false, true, true) => "aarch64-unknown-linux-gnu",
        (false, true, false) => "x86_64-unknown-linux-gnu",
        // Unsupported OS — callers treat this as "no matching asset".
        _ => "unknown",
    }
}
|
||||
|
||||
/// Return true when `candidate` is a strictly newer version than `current`.
///
/// FIX: the original compared the raw component vectors, so versions with a
/// different number of components mis-compared ("1.0" vs "1.0.0" read as
/// newer because `[1,0,0] > [1,0]` lexicographically). Components are now
/// zero-padded to equal length before comparing, making "1.0" == "1.0.0".
///
/// Non-numeric components (e.g. a "0-rc1" pre-release segment) are silently
/// dropped by the parse — pre-release ordering is not handled here.
fn version_is_newer(current: &str, candidate: &str) -> bool {
    let parse = |v: &str| -> Vec<u32> { v.split('.').filter_map(|p| p.parse().ok()).collect() };
    let mut cur = parse(current);
    let mut cand = parse(candidate);

    // Pad the shorter vector with zeros so "1.0" and "1.0.0" compare equal.
    let len = cur.len().max(cand.len());
    cur.resize(len, 0);
    cand.resize(len, 0);

    cand > cur
}
|
||||
|
||||
/// Download the release asset at `url` into `dest`.
///
/// `.tar.gz`/`.tgz` assets are unpacked (a single `zeroclaw` binary is
/// expected inside); any other asset is written verbatim. On Unix the
/// result is made executable (0755).
async fn download_binary(url: &str, dest: &Path) -> Result<()> {
    // Generous 5-minute timeout: release binaries are large.
    let client = reqwest::Client::builder()
        .user_agent(format!("zeroclaw/{}", env!("CARGO_PKG_VERSION")))
        .timeout(std::time::Duration::from_secs(300))
        .build()?;

    let resp = client
        .get(url)
        .send()
        .await
        .context("download request failed")?;
    if !resp.status().is_success() {
        bail!("download returned {}", resp.status());
    }

    // Whole body is buffered in memory before extraction/writing.
    let bytes = resp.bytes().await.context("failed to read download body")?;

    // Release assets are .tar.gz archives containing a single `zeroclaw` binary.
    // Extract the binary from the archive instead of writing the raw tarball.
    if url.ends_with(".tar.gz") || url.ends_with(".tgz") {
        extract_tar_gz(&bytes, dest).context("failed to extract binary from tar.gz archive")?;
    } else {
        tokio::fs::write(dest, &bytes)
            .await
            .context("failed to write downloaded binary")?;
    }

    // Make executable on Unix
    #[cfg(unix)]
    {
        use std::os::unix::fs::PermissionsExt;
        let perms = std::fs::Permissions::from_mode(0o755);
        tokio::fs::set_permissions(dest, perms).await?;
    }

    Ok(())
}
|
||||
|
||||
/// Extract the `zeroclaw` binary from a `.tar.gz` archive.
///
/// Scans archive entries for a file named exactly `zeroclaw` (or
/// `zeroclaw.exe`), writes its contents to `dest`, and errors if no such
/// entry exists. Other entries are ignored.
///
/// Synchronous on purpose: it operates on an in-memory byte slice, so no
/// async I/O is involved.
fn extract_tar_gz(archive_bytes: &[u8], dest: &Path) -> Result<()> {
    use flate2::read::GzDecoder;
    use std::io::Read;
    use tar::Archive;

    let gz = GzDecoder::new(archive_bytes);
    let mut archive = Archive::new(gz);

    for entry in archive.entries().context("failed to read tar entries")? {
        let mut entry = entry.context("failed to read tar entry")?;
        let path = entry.path().context("failed to read entry path")?;

        // The archive contains a single binary named "zeroclaw" (or "zeroclaw.exe" on Windows).
        let file_name = path.file_name().and_then(|n| n.to_str()).unwrap_or("");

        if file_name == "zeroclaw" || file_name == "zeroclaw.exe" {
            let mut buf = Vec::new();
            entry
                .read_to_end(&mut buf)
                .context("failed to read binary from archive")?;
            std::fs::write(dest, &buf).context("failed to write extracted binary")?;
            return Ok(());
        }
    }

    bail!("archive does not contain a 'zeroclaw' binary")
}
|
||||
|
||||
/// Validate a freshly downloaded binary before it replaces the live one.
///
/// Three gates: plausible size, matching CPU architecture, and a working
/// `--version` invocation whose output mentions "zeroclaw".
async fn validate_binary(path: &Path) -> Result<()> {
    // Size gate: a real release binary is well over 1 MB; anything smaller
    // is likely a truncated download or an error page.
    let meta = tokio::fs::metadata(path).await?;
    if meta.len() < 1_000_000 {
        bail!(
            "downloaded binary too small ({} bytes), likely corrupt",
            meta.len()
        );
    }

    // Check binary architecture before attempting execution so we can give
    // a clear diagnostic instead of the opaque "Exec format error (os error 8)".
    check_binary_arch(path).await?;

    // Quick check: try running --version
    let output = tokio::process::Command::new(path)
        .arg("--version")
        .output()
        .await
        .context("cannot execute downloaded binary")?;

    if !output.status.success() {
        bail!("downloaded binary --version check failed");
    }

    // Sanity-check the identity of what we just downloaded.
    let stdout = String::from_utf8_lossy(&output.stdout);
    if !stdout.contains("zeroclaw") {
        bail!("downloaded binary does not appear to be zeroclaw");
    }

    Ok(())
}
|
||||
|
||||
/// Read the binary header and verify its architecture matches the host.
|
||||
///
|
||||
/// On Linux/FreeBSD this reads the ELF header; on macOS the Mach-O header.
|
||||
/// If the binary is for a different architecture, returns a descriptive error
|
||||
/// instead of the opaque "Exec format error (os error 8)".
|
||||
async fn check_binary_arch(path: &Path) -> Result<()> {
|
||||
let header = tokio::fs::read(path)
|
||||
.await
|
||||
.map(|bytes| bytes.into_iter().take(32).collect::<Vec<u8>>())
|
||||
.context("failed to read binary header")?;
|
||||
|
||||
if header.len() < 20 {
|
||||
bail!("downloaded file too small to be a valid binary");
|
||||
}
|
||||
|
||||
let binary_arch = detect_arch_from_header(&header);
|
||||
let host_arch = host_architecture();
|
||||
|
||||
if let (Some(bin), Some(host)) = (binary_arch, host_arch) {
|
||||
if bin != host {
|
||||
bail!(
|
||||
"architecture mismatch: downloaded binary is {bin} but this host is {host} — \
|
||||
the release asset may be mispackaged"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Detect the CPU architecture from an ELF or Mach-O binary header.
///
/// Returns `None` when the header matches neither format (or is too short
/// for the fields that would need to be read).
fn detect_arch_from_header(header: &[u8]) -> Option<&'static str> {
    // Magic numbers: ELF is 0x7f 'E' 'L' 'F'; 64-bit little-endian Mach-O
    // is 0xFEEDFACF stored LE.
    const ELF_MAGIC: [u8; 4] = [0x7f, b'E', b'L', b'F'];
    const MACHO_MAGIC_64_LE: [u8; 4] = [0xCF, 0xFA, 0xED, 0xFE];

    if header.len() >= 20 && header[..4] == ELF_MAGIC {
        // e_machine lives at offset 18; read as little-endian.
        // NOTE(review): big-endian ELF would need EI_DATA (byte 5) checked —
        // the architectures mapped here are all little-endian in practice.
        let arch = match u16::from_le_bytes([header[18], header[19]]) {
            0x3E => "x86_64",
            0xB7 => "aarch64",
            0x03 => "x86",
            0x28 => "arm",
            0xF3 => "riscv",
            _ => "unknown-elf",
        };
        return Some(arch);
    }

    if header.len() >= 8 && header[..4] == MACHO_MAGIC_64_LE {
        // cputype is the 32-bit field immediately after the magic.
        let arch = match u32::from_le_bytes([header[4], header[5], header[6], header[7]]) {
            0x0100_0007 => "x86_64",
            0x0100_000C => "aarch64",
            _ => "unknown-macho",
        };
        return Some(arch);
    }

    None
}
|
||||
|
||||
/// Return the host CPU architecture as a human-readable string.
///
/// Names align with the values produced by `detect_arch_from_header` so
/// the two can be compared directly. `None` means "unrecognized arch".
fn host_architecture() -> Option<&'static str> {
    let arch = if cfg!(target_arch = "x86_64") {
        "x86_64"
    } else if cfg!(target_arch = "aarch64") {
        "aarch64"
    } else if cfg!(target_arch = "x86") {
        "x86"
    } else if cfg!(target_arch = "arm") {
        "arm"
    } else {
        return None;
    };
    Some(arch)
}
|
||||
|
||||
async fn swap_binary(new: &Path, target: &Path) -> Result<()> {
|
||||
// On Linux, a running binary cannot be overwritten in place (ETXTBSY).
|
||||
// Remove the old file first, then copy the new one into the now-free path.
|
||||
// This works because the kernel keeps the inode alive until the process exits.
|
||||
tokio::fs::remove_file(target)
|
||||
.await
|
||||
.context("failed to remove old binary")?;
|
||||
tokio::fs::copy(new, target)
|
||||
.await
|
||||
.context("failed to write new binary")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Restore `target` from `backup` after a failed update.
///
/// Uses remove-then-copy to avoid ETXTBSY if the target is somehow still
/// mapped; the removal is best-effort (the target may already be gone).
async fn rollback_binary(backup: &Path, target: &Path) -> Result<()> {
    let _ = tokio::fs::remove_file(target).await;

    tokio::fs::copy(backup, target)
        .await
        .context("failed to restore backup binary")
        .map(|_| ())
}
|
||||
|
||||
async fn smoke_test(binary: &Path) -> Result<()> {
|
||||
let output = tokio::process::Command::new(binary)
|
||||
.arg("--version")
|
||||
.output()
|
||||
.await
|
||||
.context("smoke test: cannot execute updated binary")?;
|
||||
|
||||
if !output.status.success() {
|
||||
bail!("smoke test: updated binary returned non-zero exit code");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Basic ordering semantics of the component-wise version comparison.
    #[test]
    fn test_version_comparison() {
        assert!(version_is_newer("0.4.3", "0.5.0"));
        assert!(version_is_newer("0.4.3", "0.4.4"));
        assert!(!version_is_newer("0.5.0", "0.4.3"));
        assert!(!version_is_newer("0.4.3", "0.4.3"));
        assert!(version_is_newer("1.0.0", "2.0.0"));
    }

    // The host triple must be a supported platform and a full triple so
    // asset-name substring matching cannot pick the wrong binary.
    #[test]
    fn current_target_triple_is_not_empty() {
        let triple = current_target_triple();
        assert_ne!(triple, "unknown", "unsupported platform");
        // The triple must contain at least two hyphens (arch-vendor-os or arch-vendor-os-env)
        assert!(
            triple.matches('-').count() >= 2,
            "triple should have at least two hyphens: {triple}"
        );
    }

    // Build a fake GitHub release JSON whose asset list has the given names.
    fn make_release(assets: &[&str]) -> serde_json::Value {
        let assets: Vec<serde_json::Value> = assets
            .iter()
            .map(|name| {
                serde_json::json!({
                    "name": name,
                    "browser_download_url": format!("https://example.com/{name}")
                })
            })
            .collect();
        serde_json::json!({ "assets": assets })
    }

    // Regression test: full triples must not substring-match an Android asset.
    #[test]
    fn find_asset_url_picks_correct_gnu_over_android() {
        let release = make_release(&[
            "zeroclaw-aarch64-linux-android.tar.gz",
            "zeroclaw-aarch64-unknown-linux-gnu.tar.gz",
            "zeroclaw-x86_64-unknown-linux-gnu.tar.gz",
            "zeroclaw-x86_64-apple-darwin.tar.gz",
            "zeroclaw-aarch64-apple-darwin.tar.gz",
        ]);

        let url = find_asset_url(&release);
        assert!(url.is_some(), "should find an asset");
        let url = url.unwrap();
        // Must NOT match the android binary
        assert!(
            !url.contains("android"),
            "should not select android binary, got: {url}"
        );
    }

    #[test]
    fn find_asset_url_returns_none_for_empty_assets() {
        let release = serde_json::json!({ "assets": [] });
        assert!(find_asset_url(&release).is_none());
    }

    #[test]
    fn find_asset_url_returns_none_for_missing_assets() {
        let release = serde_json::json!({});
        assert!(find_asset_url(&release).is_none());
    }

    #[test]
    fn detect_arch_elf_x86_64() {
        // Minimal ELF header with e_machine = 0x3E (x86_64)
        let mut header = vec![0u8; 20];
        header[0..4].copy_from_slice(&[0x7f, b'E', b'L', b'F']);
        header[18] = 0x3E;
        header[19] = 0x00;
        assert_eq!(detect_arch_from_header(&header), Some("x86_64"));
    }

    #[test]
    fn detect_arch_elf_aarch64() {
        // e_machine = 0xB7 (aarch64)
        let mut header = vec![0u8; 20];
        header[0..4].copy_from_slice(&[0x7f, b'E', b'L', b'F']);
        header[18] = 0xB7;
        header[19] = 0x00;
        assert_eq!(detect_arch_from_header(&header), Some("aarch64"));
    }

    #[test]
    fn detect_arch_macho_x86_64() {
        // Mach-O 64-bit LE magic + cputype 0x01000007 (x86_64)
        let mut header = vec![0u8; 8];
        header[0..4].copy_from_slice(&[0xCF, 0xFA, 0xED, 0xFE]);
        header[4..8].copy_from_slice(&0x0100_0007u32.to_le_bytes());
        assert_eq!(detect_arch_from_header(&header), Some("x86_64"));
    }

    #[test]
    fn detect_arch_macho_aarch64() {
        // cputype 0x0100000C (arm64)
        let mut header = vec![0u8; 8];
        header[0..4].copy_from_slice(&[0xCF, 0xFA, 0xED, 0xFE]);
        header[4..8].copy_from_slice(&0x0100_000Cu32.to_le_bytes());
        assert_eq!(detect_arch_from_header(&header), Some("aarch64"));
    }

    #[test]
    fn detect_arch_unknown_format() {
        let header = vec![0u8; 20]; // all zeros — not ELF or Mach-O
        assert_eq!(detect_arch_from_header(&header), None);
    }

    #[test]
    fn detect_arch_too_short() {
        let header = vec![0x7f, b'E', b'L', b'F']; // only 4 bytes
        assert_eq!(detect_arch_from_header(&header), None);
    }

    #[test]
    fn host_architecture_is_known() {
        assert!(
            host_architecture().is_some(),
            "host architecture should be detected on CI platforms"
        );
    }

    // Round-trip: pack a fake binary into a tar.gz, then extract and compare.
    #[test]
    fn extract_tar_gz_finds_binary() {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;

        // Build a tar.gz in memory containing a fake "zeroclaw" binary.
        let fake_binary = b"#!/bin/sh\necho zeroclaw";
        let mut tar_buf = Vec::new();
        {
            let mut builder = tar::Builder::new(&mut tar_buf);
            let mut header = tar::Header::new_gnu();
            header.set_size(fake_binary.len() as u64);
            header.set_mode(0o755);
            header.set_cksum();
            builder
                .append_data(&mut header, "zeroclaw", &fake_binary[..])
                .unwrap();
            builder.finish().unwrap();
        }

        let mut gz_buf = Vec::new();
        {
            let mut encoder = GzEncoder::new(&mut gz_buf, Compression::fast());
            encoder.write_all(&tar_buf).unwrap();
            encoder.finish().unwrap();
        }

        let tmp = tempfile::tempdir().unwrap();
        let dest = tmp.path().join("zeroclaw_extracted");
        extract_tar_gz(&gz_buf, &dest).unwrap();

        let content = std::fs::read(&dest).unwrap();
        assert_eq!(content, fake_binary);
    }

    // An archive without a "zeroclaw" entry must produce a descriptive error.
    #[test]
    fn extract_tar_gz_errors_on_missing_binary() {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;

        // Build a tar.gz with a file that is NOT named "zeroclaw".
        let mut tar_buf = Vec::new();
        {
            let mut builder = tar::Builder::new(&mut tar_buf);
            let mut header = tar::Header::new_gnu();
            header.set_size(5);
            header.set_mode(0o644);
            header.set_cksum();
            builder
                .append_data(&mut header, "README.md", &b"hello"[..])
                .unwrap();
            builder.finish().unwrap();
        }

        let mut gz_buf = Vec::new();
        {
            let mut encoder = GzEncoder::new(&mut gz_buf, Compression::fast());
            encoder.write_all(&tar_buf).unwrap();
            encoder.finish().unwrap();
        }

        let tmp = tempfile::tempdir().unwrap();
        let dest = tmp.path().join("zeroclaw_extracted");
        let result = extract_tar_gz(&gz_buf, &dest);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().to_string().contains("does not contain"),
            "should report missing binary"
        );
    }
}
|
||||
Reference in New Issue
Block a user