From a4e366a35670930397a878912ffdfa30ef8c2762 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 27 Dec 2024 14:19:27 +0100 Subject: [PATCH] Custom progressbar. Using a trait without fat pointer. (#80) * Tmp. Custom progressbar. Allow internal state. More API. Clippy. Remove print statements Using stream feature for shorter lived progressbars. Moved muliplexing behind less obvious builder (needs more testing to showcase pros and cons). * Fixing after rebase. * Showcase the progressbar with iced example. * Fix the example. --- Cargo.toml | 3 +- README.md | 2 +- examples/download.rs | 7 +- examples/iced/.gitignore | 1 + examples/iced/Cargo.toml | 8 ++ examples/iced/src/main.rs | 244 ++++++++++++++++++++++++++++++++++++++ src/api/mod.rs | 47 ++++++++ src/api/sync.rs | 169 +++++++++++++++----------- src/api/tokio.rs | 208 +++++++++++++++++++++++--------- 9 files changed, 558 insertions(+), 131 deletions(-) create mode 100644 examples/iced/.gitignore create mode 100644 examples/iced/Cargo.toml create mode 100644 examples/iced/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index c24367e..fce8734 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ num_cpus = { version = "1.15.0", optional = true } rand = { version = "0.8.5", optional = true } reqwest = { version = "0.12.2", optional = true, default-features = false, features = [ "json", + "stream", ] } rustls = { version = "0.23.4", optional = true } serde = { version = "1", features = ["derive"], optional = true } @@ -41,7 +42,7 @@ native-tls = { version = "0.2.12", optional = true } [features] default = ["default-tls", "tokio", "ureq"] # These features are only relevant when used with the `tokio` feature, but this might change in the future. -default-tls = [] +default-tls = ["native-tls"] native-tls = ["dep:reqwest", "reqwest/default", "dep:native-tls", "ureq/native-tls"] rustls-tls = ["dep:rustls", "reqwest/rustls-tls"] tokio = [ diff --git a/README.md b/README.md index 6876c46..6d0b0d9 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ let _filename = repo.get("config.json").unwrap(); # SSL/TLS -This library uses its dependencies' default TLS implementations which are `rustls` for `ureq` (sync) and `native-tls` (openssl) for `tokio`. +This library uses tokio default TLS implementations which is `native-tls` (openssl) for `tokio`. If you want control over the TLS backend you can remove the default features and only add the backend you are intending to use. diff --git a/examples/download.rs b/examples/download.rs index 7aea0c6..c59a663 100644 --- a/examples/download.rs +++ b/examples/download.rs @@ -1,10 +1,9 @@ #[cfg(not(feature = "ureq"))] -#[cfg(not(feature="tokio"))] -fn main() { -} +#[cfg(not(feature = "tokio"))] +fn main() {} #[cfg(feature = "ureq")] -#[cfg(not(feature="tokio"))] +#[cfg(not(feature = "tokio"))] fn main() { let api = hf_hub::api::sync::Api::new().unwrap(); diff --git a/examples/iced/.gitignore b/examples/iced/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/examples/iced/.gitignore @@ -0,0 +1 @@ +/target diff --git a/examples/iced/Cargo.toml b/examples/iced/Cargo.toml new file mode 100644 index 0000000..86bf96b --- /dev/null +++ b/examples/iced/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "iced_hf_hub" +version = "0.1.0" +edition = "2021" + +[dependencies] +iced = { version = "0.13.1", features = ["tokio"] } +hf-hub = { path = "../../", default-features = false, features = ["tokio", "rustls-tls"] } diff --git a/examples/iced/src/main.rs b/examples/iced/src/main.rs new file mode 100644 index 0000000..a9e0d43 --- /dev/null +++ b/examples/iced/src/main.rs @@ -0,0 +1,244 @@ +use hf_hub::api::tokio::{Api, ApiError}; +use iced::futures::{SinkExt, Stream}; +use iced::stream::try_channel; +use iced::task; +use iced::widget::{button, center, column, progress_bar, text, Column}; + +use iced::{Center, Element, Right, Task}; + +#[derive(Debug, Clone)] +pub enum Progress { + Downloading { current: usize, total: usize }, + Finished, +} + +#[derive(Debug, Clone)] +pub enum Error { + Api(String), +} + +impl From for Error { + fn from(value: ApiError) -> Self { + Self::Api(value.to_string()) + } +} + +pub fn main() -> iced::Result { + iced::application("Download Progress - Iced", Example::update, Example::view).run() +} + +#[derive(Debug)] +struct Example { + downloads: Vec, + last_id: usize, +} + +#[derive(Clone)] +struct Prog { + output: iced::futures::channel::mpsc::Sender, + total: usize, +} + +impl hf_hub::api::tokio::Progress for Prog { + async fn update(&mut self, size: usize) { + let _ = self + .output + .send(Progress::Downloading { + current: size, + total: self.total, + }) + .await; + } + async fn finish(&mut self) { + let _ = self.output.send(Progress::Finished).await; + } + + async fn init(&mut self, size: usize, _filename: &str) { + println!("Initiating {size}"); + let _ = self + .output + .send(Progress::Downloading { + current: 0, + total: size, + }) + .await; + self.total = size; + } +} + +pub fn download( + repo: String, + filename: impl AsRef, +) -> impl Stream> { + try_channel(1, move |output| async move { + let prog = Prog { output, total: 0 }; + + let api = Api::new().unwrap().model(repo); + api.download_with_progress(filename.as_ref(), prog).await?; + + Ok(()) + }) +} + +#[derive(Debug, Clone)] +pub enum Message { + Add, + Download(usize), + DownloadProgressed(usize, Result), +} + +impl Example { + fn new() -> Self { + Self { + downloads: vec![Download::new(0)], + last_id: 0, + } + } + + fn update(&mut self, message: Message) -> Task { + match message { + Message::Add => { + self.last_id += 1; + + self.downloads.push(Download::new(self.last_id)); + + Task::none() + } + Message::Download(index) => { + let Some(download) = self.downloads.get_mut(index) else { + return Task::none(); + }; + + let task = download.start(); + + task.map(move |progress| Message::DownloadProgressed(index, progress)) + } + Message::DownloadProgressed(id, progress) => { + if let Some(download) = self.downloads.iter_mut().find(|download| download.id == id) + { + download.progress(progress); + } + + Task::none() + } + } + } + + fn view(&self) -> Element { + let downloads = Column::with_children(self.downloads.iter().map(Download::view)) + .push( + button("Add another download") + .on_press(Message::Add) + .padding(10), + ) + .spacing(20) + .align_x(Right); + + center(downloads).padding(20).into() + } +} + +impl Default for Example { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug)] +struct Download { + id: usize, + state: State, +} + +#[derive(Debug)] +enum State { + Idle, + Downloading { progress: f32, _task: task::Handle }, + Finished, + Errored, +} + +impl Download { + pub fn new(id: usize) -> Self { + Download { + id, + state: State::Idle, + } + } + + pub fn start(&mut self) -> Task> { + match self.state { + State::Idle { .. } | State::Finished { .. } | State::Errored { .. } => { + let (task, handle) = Task::stream(download( + "mattshumer/Reflection-Llama-3.1-70B".to_string(), + "model-00001-of-00162.safetensors", + )) + .abortable(); + + self.state = State::Downloading { + progress: 0.0, + _task: handle.abort_on_drop(), + }; + + task + } + State::Downloading { .. } => Task::none(), + } + } + + pub fn progress(&mut self, new_progress: Result) { + if let State::Downloading { progress, .. } = &mut self.state { + match new_progress { + Ok(Progress::Downloading { current, total }) => { + println!("Status {progress} - {current}"); + let new_progress = current as f32 / total as f32 * 100.0; + println!("New progress {current} {new_progress}"); + *progress += new_progress; + } + Ok(Progress::Finished) => { + self.state = State::Finished; + } + Err(_error) => { + self.state = State::Errored; + } + } + } + } + + pub fn view(&self) -> Element { + let current_progress = match &self.state { + State::Idle { .. } => 0.0, + State::Downloading { progress, .. } => *progress, + State::Finished { .. } => 100.0, + State::Errored { .. } => 0.0, + }; + + let progress_bar = progress_bar(0.0..=100.0, current_progress); + + let control: Element<_> = match &self.state { + State::Idle => button("Start the download!") + .on_press(Message::Download(self.id)) + .into(), + State::Finished => column!["Download finished!", button("Start again")] + .spacing(10) + .align_x(Center) + .into(), + State::Downloading { .. } => text!("Downloading... {current_progress:.2}%").into(), + State::Errored => column![ + "Something went wrong :(", + button("Try again").on_press(Message::Download(self.id)), + ] + .spacing(10) + .align_x(Center) + .into(), + }; + + Column::new() + .spacing(10) + .padding(10) + .align_x(Center) + .push(progress_bar) + .push(control) + .into() + } +} diff --git a/src/api/mod.rs b/src/api/mod.rs index ef738ca..a5bc6a4 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,3 +1,4 @@ +use indicatif::{ProgressBar, ProgressStyle}; use serde::Deserialize; /// The asynchronous version of the API @@ -8,6 +9,52 @@ pub mod tokio; #[cfg(feature = "ureq")] pub mod sync; +/// This trait is used by users of the lib +/// to implement custom behavior during file downloads +pub trait Progress { + /// At the start of the download + /// The size is the total size in bytes of the file. + fn init(&mut self, size: usize, filename: &str); + /// This function is called whenever `size` bytes have been + /// downloaded in the temporary file + fn update(&mut self, size: usize); + /// This is called at the end of the download + fn finish(&mut self); +} + +impl Progress for () { + fn init(&mut self, _size: usize, _filename: &str) {} + fn update(&mut self, _size: usize) {} + fn finish(&mut self) {} +} + +impl Progress for ProgressBar { + fn init(&mut self, size: usize, filename: &str) { + self.set_length(size as u64); + self.set_style( + ProgressStyle::with_template( + "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", + ) + .unwrap(), // .progress_chars("━ "), + ); + let maxlength = 30; + let message = if filename.len() > maxlength { + format!("..{}", &filename[filename.len() - maxlength..]) + } else { + filename.to_string() + }; + self.set_message(message); + } + + fn update(&mut self, size: usize) { + self.inc(size as u64) + } + + fn finish(&mut self) { + ProgressBar::finish(self); + } +} + /// Siblings are simplified file descriptions of remote files on the hub #[derive(Debug, Clone, Deserialize, PartialEq)] pub struct Siblings { diff --git a/src/api/sync.rs b/src/api/sync.rs index 31d627f..a0f96b2 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -1,10 +1,12 @@ use super::RepoInfo; use crate::api::sync::ApiError::InvalidHeader; +use crate::api::Progress; use crate::{Cache, Repo, RepoType}; use http::{StatusCode, Uri}; -use indicatif::{ProgressBar, ProgressStyle}; +use indicatif::ProgressBar; use rand::Rng; use std::collections::HashMap; +use std::io::Read; use std::io::Seek; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; @@ -26,6 +28,23 @@ const AUTHORIZATION: &str = "Authorization"; type HeaderMap = HashMap<&'static str, String>; type HeaderName = &'static str; +struct Wrapper<'a, P: Progress, R: Read> { + progress: &'a mut P, + inner: R, +} + +fn wrap_read(inner: R, progress: &mut P) -> Wrapper { + Wrapper { inner, progress } +} + +impl Read for Wrapper<'_, P, R> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let read = self.inner.read(buf)?; + self.progress.update(read); + Ok(read) + } +} + /// Simple wrapper over [`ureq::Agent`] to include default headers #[derive(Clone, Debug)] pub struct HeaderAgent { @@ -92,7 +111,6 @@ pub enum ApiError { pub struct ApiBuilder { endpoint: String, cache: Cache, - url_template: String, token: Option, max_retries: usize, progress: bool, @@ -128,12 +146,10 @@ impl ApiBuilder { let max_retries = 0; let progress = true; - let endpoint = - std::env::var("HF_ENDPOINT").unwrap_or_else(|_| "https://huggingface.co".to_owned()); + let endpoint = "https://huggingface.co".to_string(); Self { endpoint, - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), cache, token, max_retries, @@ -197,10 +213,8 @@ impl ApiBuilder { Ok(Api { endpoint: self.endpoint, - url_template: self.url_template, cache: self.cache, client, - no_redirect_client, max_retries: self.max_retries, progress: self.progress, @@ -216,12 +230,10 @@ struct Metadata { } /// The actual Api used to interacto with the hub. -/// You can inspect repos with [`Api::info`] -/// or download files with [`Api::download`] +/// Use any repo with [`Api::repo`] #[derive(Clone, Debug)] pub struct Api { endpoint: String, - url_template: String, cache: Cache, client: HeaderAgent, no_redirect_client: HeaderAgent, @@ -402,57 +414,47 @@ impl Api { }) } - fn download_tempfile( + fn download_tempfile( &self, url: &str, - progressbar: Option, + mut progress: P, + filename: &str, ) -> Result { - let filename = self.cache.temp_path(); + let filepath = self.cache.temp_path(); // Create the file and set everything properly - let mut file = std::fs::File::create(&filename)?; + let mut file = std::fs::File::create(&filepath)?; + let mut res = self.download_from(url, 0u64, &mut file, filename, &mut progress); if self.max_retries > 0 { let mut i = 0; - let mut res = self.download_from(url, 0u64, &mut file); while let Err(dlerr) = res { let wait_time = exponential_backoff(300, i, 10_000); std::thread::sleep(std::time::Duration::from_millis(wait_time as u64)); - res = self.download_from(url, file.stream_position()?, &mut file); + let current = file.stream_position()?; + res = self.download_from(url, current, &mut file, filename, &mut progress); i += 1; if i > self.max_retries { return Err(ApiError::TooManyRetries(dlerr.into())); } } - res?; - if let Some(p) = progressbar { - p.finish() - } - return Ok(filename); } - - let response = self.client.get(url).call().map_err(Box::new)?; - - let mut reader = response.into_reader(); - if let Some(p) = &progressbar { - reader = Box::new(p.wrap_read(reader)); - } - - std::io::copy(&mut reader, &mut file)?; - - if let Some(p) = progressbar { - p.finish(); - } - Ok(filename) + res?; + Ok(filepath) } - fn download_from( + fn download_from

( &self, url: &str, current: u64, file: &mut std::fs::File, - ) -> Result<(), ApiError> { + filename: &str, + progress: &mut P, + ) -> Result<(), ApiError> + where + P: Progress, + { let range = format!("bytes={current}-"); let response = self .client @@ -460,8 +462,12 @@ impl Api { .set(RANGE, &range) .call() .map_err(Box::new)?; - let mut reader = response.into_reader(); + let reader = response.into_reader(); + progress.init(0, filename); + progress.update(current as usize); + let mut reader = Box::new(wrap_read(reader, progress)); std::io::copy(&mut reader, file)?; + progress.finish(); Ok(()) } @@ -506,6 +512,8 @@ impl Api { } /// Shorthand for accessing things within a particular repo +/// You can inspect repos with [`ApiRepo::info`] +/// or download files with [`ApiRepo::download`] #[derive(Debug)] pub struct ApiRepo { api: Api, @@ -541,12 +549,8 @@ impl ApiRepo { pub fn url(&self, filename: &str) -> String { let endpoint = &self.api.endpoint; let revision = &self.repo.url_revision(); - self.api - .url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &self.repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) + let repo_id = self.repo.url(); + format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}") } /// This will attempt the fetch the file locally first, then [`Api.download`] @@ -563,16 +567,40 @@ impl ApiRepo { } } - /// Downloads a remote file (if not already present) into the cache directory - /// to be used locally. - /// This functions require internet access to verify if new versions of the file - /// exist, even if a file is already on disk at location. + /// This function is used to download a file with a custom progress function. + /// It uses the [`Progress`] trait and can be used in more complex use + /// cases like downloading a showing progress in a UI. /// ```no_run - /// # use hf_hub::api::sync::Api; + /// # use hf_hub::api::{sync::Api, Progress}; + /// struct MyProgress{ + /// current: usize, + /// total: usize + /// } + /// + /// impl Progress for MyProgress{ + /// fn init(&mut self, size: usize, _filename: &str){ + /// self.total = size; + /// self.current = 0; + /// } + /// + /// fn update(&mut self, size: usize){ + /// self.current += size; + /// println!("{}/{}", self.current, self.total) + /// } + /// + /// fn finish(&mut self){ + /// println!("Done !"); + /// } + /// } /// let api = Api::new().unwrap(); - /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap(); + /// let progress = MyProgress{current: 0, total: 0}; + /// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).unwrap(); /// ``` - pub fn download(&self, filename: &str) -> Result { + pub fn download_with_progress( + &self, + filename: &str, + mut progress: P, + ) -> Result { let url = self.url(filename); let metadata = self.api.metadata(&url)?; @@ -583,27 +611,9 @@ impl ApiRepo { .blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; - let progressbar = if self.api.progress { - let progress = ProgressBar::new(metadata.size as u64); - progress.set_style( - ProgressStyle::with_template( - "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", - ) - .unwrap(), // .progress_chars("━ "), - ); - let maxlength = 30; - let message = if filename.len() > maxlength { - format!("..{}", &filename[filename.len() - maxlength..]) - } else { - filename.to_string() - }; - progress.set_message(message); - Some(progress) - } else { - None - }; + progress.init(metadata.size, filename); - let tmp_filename = self.api.download_tempfile(&url, progressbar)?; + let tmp_filename = self.api.download_tempfile(&url, progress, filename)?; std::fs::rename(tmp_filename, &blob_path)?; @@ -624,6 +634,23 @@ impl ApiRepo { Ok(pointer_path) } + /// Downloads a remote file (if not already present) into the cache directory + /// to be used locally. + /// This functions require internet access to verify if new versions of the file + /// exist, even if a file is already on disk at location. + /// ```no_run + /// # use hf_hub::api::sync::Api; + /// let api = Api::new().unwrap(); + /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap(); + /// ``` + pub fn download(&self, filename: &str) -> Result { + if self.api.progress { + self.download_with_progress(filename, ProgressBar::new(0)) + } else { + self.download_with_progress(filename, ()) + } + } + /// Get information about the Repo /// ``` /// use hf_hub::{api::sync::Api}; diff --git a/src/api/tokio.rs b/src/api/tokio.rs index b3f4959..15d06f9 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -1,6 +1,8 @@ +use super::Progress as SyncProgress; use super::RepoInfo; use crate::{Cache, Repo, RepoType}; -use indicatif::{ProgressBar, ProgressStyle}; +use futures::StreamExt; +use indicatif::ProgressBar; use rand::Rng; use reqwest::{ header::{ @@ -22,6 +24,38 @@ const VERSION: &str = env!("CARGO_PKG_VERSION"); /// Current name (used in user-agent) const NAME: &str = env!("CARGO_PKG_NAME"); +/// This trait is used by users of the lib +/// to implement custom behavior during file downloads +pub trait Progress { + /// At the start of the download + /// The size is the total size in bytes of the file. + fn init(&mut self, size: usize, filename: &str) + -> impl std::future::Future + Send; + /// This function is called whenever `size` bytes have been + /// downloaded in the temporary file + fn update(&mut self, size: usize) -> impl std::future::Future + Send; + /// This is called at the end of the download + fn finish(&mut self) -> impl std::future::Future + Send; +} + +impl Progress for ProgressBar { + async fn init(&mut self, size: usize, filename: &str) { + ::init(self, size, filename); + } + async fn finish(&mut self) { + ::finish(self); + } + async fn update(&mut self, size: usize) { + ::update(self, size); + } +} + +impl Progress for () { + async fn init(&mut self, _size: usize, _filename: &str) {} + async fn finish(&mut self) {} + async fn update(&mut self, _size: usize) {} +} + #[derive(Debug, Error)] /// All errors the API can throw pub enum ApiError { @@ -74,10 +108,9 @@ pub enum ApiError { pub struct ApiBuilder { endpoint: String, cache: Cache, - url_template: String, token: Option, max_files: usize, - chunk_size: usize, + chunk_size: Option, parallel_failures: usize, max_retries: usize, progress: bool, @@ -100,6 +133,23 @@ impl ApiBuilder { Self::from_cache(cache) } + /// High CPU download + /// + /// This may cause issues on regular desktops as it will saturate + /// CPUs by multiplexing the downloads. + /// However on high CPU machines on the cloud, this may help + /// saturate the bandwidth (>500MB/s) better. + /// ``` + /// use hf_hub::api::tokio::ApiBuilder; + /// let api = ApiBuilder::high().build().unwrap(); + /// ``` + pub fn high() -> Self { + let cache = Cache::default(); + Self::from_cache(cache) + .with_max_files(num_cpus::get()) + .with_chunk_size(Some(10_000_000)) + } + /// From a given cache /// ``` /// use hf_hub::{api::tokio::ApiBuilder, Cache}; @@ -114,11 +164,11 @@ impl ApiBuilder { Self { endpoint: "https://huggingface.co".to_string(), - url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), cache, token, - max_files: num_cpus::get(), - chunk_size: 10_000_000, + max_files: 1, + // chunk_size: 10_000_000, + chunk_size: None, parallel_failures: 0, max_retries: 0, progress, @@ -130,7 +180,7 @@ impl ApiBuilder { self.progress = progress; self } - + /// Changes the endpoint of the API. Default is `https://huggingface.co`. pub fn with_endpoint(mut self, endpoint: String) -> Self { self.endpoint = endpoint; @@ -149,6 +199,18 @@ impl ApiBuilder { self } + /// Sets the number of open files + pub fn with_max_files(mut self, max_files: usize) -> Self { + self.max_files = max_files; + self + } + + /// Sets the size of each chunk + pub fn with_chunk_size(mut self, chunk_size: Option) -> Self { + self.chunk_size = chunk_size; + self + } + fn build_headers(&self) -> Result { let mut headers = HeaderMap::new(); let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); @@ -192,7 +254,6 @@ impl ApiBuilder { .build()?; Ok(Api { endpoint: self.endpoint, - url_template: self.url_template, cache: self.cache, client, relative_redirect_client, @@ -213,17 +274,15 @@ struct Metadata { } /// The actual Api used to interact with the hub. -/// You can inspect repos with [`Api::info`] -/// or download files with [`Api::download`] +/// Use any repo with [`Api::repo`] #[derive(Clone, Debug)] pub struct Api { endpoint: String, - url_template: String, cache: Cache, client: Client, relative_redirect_client: Client, max_files: usize, - chunk_size: usize, + chunk_size: Option, parallel_failures: usize, max_retries: usize, progress: bool, @@ -400,6 +459,8 @@ impl Api { } /// Shorthand for accessing things within a particular repo +/// You can inspect repos with [`ApiRepo::info`] +/// or download files with [`ApiRepo::download`] #[derive(Debug)] pub struct ApiRepo { api: Api, @@ -423,19 +484,15 @@ impl ApiRepo { pub fn url(&self, filename: &str) -> String { let endpoint = &self.api.endpoint; let revision = &self.repo.url_revision(); - self.api - .url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &self.repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) + let repo_id = self.repo.url(); + format!("{endpoint}/{repo_id}/resolve/{revision}/{filename}") } - async fn download_tempfile( + async fn download_tempfile<'a, P: Progress + Clone + Send + Sync + 'static>( &self, url: &str, length: usize, - progressbar: Option, + mut progressbar: P, ) -> Result { let mut handles = vec![]; let semaphore = Arc::new(Semaphore::new(self.api.max_files)); @@ -448,7 +505,7 @@ impl ApiRepo { .set_len(length as u64) .await?; - let chunk_size = self.api.chunk_size; + let chunk_size = self.api.chunk_size.unwrap_or(length); for start in (0..length).step_by(chunk_size) { let url = url.to_string(); let filename = filename.clone(); @@ -461,7 +518,9 @@ impl ApiRepo { let parallel_failures_semaphore = parallel_failures_semaphore.clone(); let progress = progressbar.clone(); handles.push(tokio::spawn(async move { - let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; + let mut chunk = + Self::download_chunk(&client, &url, &filename, start, stop, progress.clone()) + .await; let mut i = 0; if parallel_failures > 0 { while let Err(dlerr) = chunk { @@ -472,7 +531,15 @@ impl ApiRepo { tokio::time::sleep(tokio::time::Duration::from_millis(wait_time as u64)) .await; - chunk = Self::download_chunk(&client, &url, &filename, start, stop).await; + chunk = Self::download_chunk( + &client, + &url, + &filename, + start, + stop, + progress.clone(), + ) + .await; i += 1; if i > max_retries { return Err(ApiError::TooManyRetries(dlerr.into())); @@ -481,9 +548,9 @@ impl ApiRepo { } } drop(permit); - if let Some(p) = progress { - p.inc((stop - start) as u64); - } + // if let Some(p) = progress { + // progress.update(stop - start).await; + // } chunk })); } @@ -493,19 +560,21 @@ impl ApiRepo { futures::future::join_all(handles).await; let results: Result<(), ApiError> = results.into_iter().flatten().collect(); results?; - if let Some(p) = progressbar { - p.finish(); - } + progressbar.finish().await; Ok(filename) } - async fn download_chunk( + async fn download_chunk

( client: &reqwest::Client, url: &str, filename: &PathBuf, start: usize, stop: usize, - ) -> Result<(), ApiError> { + mut progress: P, + ) -> Result<(), ApiError> + where + P: Progress, + { // Process each socket concurrently. let range = format!("bytes={start}-{stop}"); let mut file = tokio::fs::OpenOptions::new() @@ -519,8 +588,12 @@ impl ApiRepo { .send() .await? .error_for_status()?; - let content = response.bytes().await?; - file.write_all(&content).await?; + let mut byte_stream = response.bytes_stream(); + while let Some(next) = byte_stream.next().await { + let next = next?; + file.write_all(&next).await?; + progress.update(next.len()).await; + } Ok(()) } @@ -552,6 +625,52 @@ impl ApiRepo { /// # }) /// ``` pub async fn download(&self, filename: &str) -> Result { + if self.api.progress { + self.download_with_progress(filename, ProgressBar::new(0)) + .await + } else { + self.download_with_progress(filename, ()).await + } + } + + /// This function is used to download a file with a custom progress function. + /// It uses the [`Progress`] trait and can be used in more complex use + /// cases like downloading a showing progress in a UI. + /// ```no_run + /// use hf_hub::api::tokio::{Api, Progress}; + /// + /// #[derive(Clone)] + /// struct MyProgress{ + /// current: usize, + /// total: usize + /// } + /// + /// impl Progress for MyProgress{ + /// async fn init(&mut self, size: usize, _filename: &str){ + /// self.total = size; + /// self.current = 0; + /// } + /// + /// async fn update(&mut self, size: usize){ + /// self.current += size; + /// println!("{}/{}", self.current, self.total) + /// } + /// + /// async fn finish(&mut self){ + /// println!("Done !"); + /// } + /// } + /// # tokio_test::block_on(async { + /// let api = Api::new().unwrap(); + /// let progress = MyProgress{ current: 0, total : 0}; + /// let local_filename = api.model("gpt2".to_string()).download_with_progress("model.safetensors", progress).await.unwrap(); + /// # }) + /// ``` + pub async fn download_with_progress( + &self, + filename: &str, + mut progress: P, + ) -> Result { let url = self.url(filename); let metadata = self.api.metadata(&url).await?; let cache = self.api.cache.repo(self.repo.clone()); @@ -559,28 +678,9 @@ impl ApiRepo { let blob_path = cache.blob_path(&metadata.etag); std::fs::create_dir_all(blob_path.parent().unwrap())?; - let progressbar = if self.api.progress { - let progress = ProgressBar::new(metadata.size as u64); - progress.set_style( - ProgressStyle::with_template( - "{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})", - ) - .unwrap(), // .progress_chars("━ "), - ); - let maxlength = 30; - let message = if filename.len() > maxlength { - format!("..{}", &filename[filename.len() - maxlength..]) - } else { - filename.to_string() - }; - progress.set_message(message); - Some(progress) - } else { - None - }; - + progress.init(metadata.size, filename).await; let tmp_filename = self - .download_tempfile(&url, metadata.size, progressbar) + .download_tempfile(&url, metadata.size, progress) .await?; tokio::fs::rename(&tmp_filename, &blob_path).await?;