Skip to content

Commit

Permalink
Feat/sync retry download (#23)
Browse files Browse the repository at this point in the history
* Add max_retires field to sync api

* Add basic retry capability to sync download

* Add simple test for download with retries
without simulating a failure

* Use stream position when continuing download
and remove ResumableReader

* Fix conflict.

---------

Co-authored-by: Nicolas Patry <[email protected]>
  • Loading branch information
benedikt-schaber and Narsil authored Dec 25, 2024
1 parent 1c21bc6 commit edda880
Showing 1 changed file with 87 additions and 0 deletions.
87 changes: 87 additions & 0 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use crate::api::sync::ApiError::InvalidHeader;
use crate::{Cache, Repo, RepoType};
use http::{StatusCode, Uri};
use indicatif::{ProgressBar, ProgressStyle};
use rand::Rng;
use std::collections::HashMap;
use std::io::Seek;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use std::str::FromStr;
Expand Down Expand Up @@ -92,6 +94,7 @@ pub struct ApiBuilder {
cache: Cache,
url_template: String,
token: Option<String>,
max_retries: usize,
progress: bool,
}

Expand Down Expand Up @@ -122,6 +125,7 @@ impl ApiBuilder {
pub fn from_cache(cache: Cache) -> Self {
let token = cache.token();

let max_retries = 0;
let progress = true;

let endpoint =
Expand All @@ -132,6 +136,7 @@ impl ApiBuilder {
url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(),
cache,
token,
max_retries,
progress,
}
}
Expand Down Expand Up @@ -160,6 +165,12 @@ impl ApiBuilder {
self
}

/// Sets the number of times the API will retry to download a file
pub fn with_retries(mut self, max_retries: usize) -> Self {
self.max_retries = max_retries;
self
}

fn build_headers(&self) -> HeaderMap {
let mut headers = HeaderMap::new();
let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown");
Expand Down Expand Up @@ -191,6 +202,7 @@ impl ApiBuilder {
client,

no_redirect_client,
max_retries: self.max_retries,
progress: self.progress,
})
}
Expand All @@ -213,6 +225,7 @@ pub struct Api {
cache: Cache,
client: HeaderAgent,
no_redirect_client: HeaderAgent,
max_retries: usize,
progress: bool,
}

Expand Down Expand Up @@ -270,6 +283,14 @@ fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> {
Ok(())
}

fn jitter() -> usize {
rand::thread_rng().gen_range(0..=500)
}

fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize {
(base_wait_time + n.pow(2) + jitter()).min(max)
}

impl Api {
/// Creates a default Api, for Api options See [`ApiBuilder`]
pub fn new() -> Result<Self, ApiError> {
Expand Down Expand Up @@ -391,6 +412,26 @@ impl Api {
// Create the file and set everything properly
let mut file = std::fs::File::create(&filename)?;

if self.max_retries > 0 {
let mut i = 0;
let mut res = self.download_from(url, 0u64, &mut file);
while let Err(dlerr) = res {
let wait_time = exponential_backoff(300, i, 10_000);
std::thread::sleep(std::time::Duration::from_millis(wait_time as u64));

res = self.download_from(url, file.stream_position()?, &mut file);
i += 1;
if i > self.max_retries {
return Err(ApiError::TooManyRetries(dlerr.into()));
}
}
res?;
if let Some(p) = progressbar {
p.finish()
}
return Ok(filename);
}

let response = self.client.get(url).call().map_err(Box::new)?;

let mut reader = response.into_reader();
Expand All @@ -406,6 +447,24 @@ impl Api {
Ok(filename)
}

fn download_from(
&self,
url: &str,
current: u64,
file: &mut std::fs::File,
) -> Result<(), ApiError> {
let range = format!("bytes={current}-");
let response = self
.client
.get(url)
.set(RANGE, &range)
.call()
.map_err(Box::new)?;
let mut reader = response.into_reader();
std::io::copy(&mut reader, file)?;
Ok(())
}

/// Creates a new handle [`ApiRepo`] which contains operations
/// on a particular [`Repo`]
pub fn repo(&self, repo: Repo) -> ApiRepo {
Expand Down Expand Up @@ -651,6 +710,34 @@ mod tests {
assert_eq!(cache_path, downloaded_path);
}

#[test]
fn simple_with_retries() {
let tmp = TempDir::new();
let api = ApiBuilder::new()
.with_progress(false)
.with_cache_dir(tmp.path.clone())
.with_retries(3)
.build()
.unwrap();

let model_id = "julien-c/dummy-unknown".to_string();
let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
val[..],
hex!("b908f2b7227d4d31a2105dfa31095e28d304f9bc938bfaaa57ee2cacf1f62d32")
);

// Make sure the file is now seeable without connection
let cache_path = api
.cache
.repo(Repo::new(model_id, RepoType::Model))
.get("config.json")
.unwrap();
assert_eq!(cache_path, downloaded_path);
}

#[test]
fn dataset() {
let tmp = TempDir::new();
Expand Down

0 comments on commit edda880

Please sign in to comment.