Skip to content

Commit

Permalink
Custom progressbar. Using a trait without fat pointer. (#80)
Browse files Browse the repository at this point in the history
* Tmp.

Custom progressbar.

Allow internal state.

More API.

Clippy.

Remove print statements

Using stream feature for shorter lived progressbars.

Moved muliplexing behind less obvious builder (needs more testing to
showcase pros and cons).

* Fixing after rebase.

* Showcase the progressbar with iced example.

* Fix the example.
  • Loading branch information
Narsil authored Dec 27, 2024
1 parent 206b344 commit a4e366a
Show file tree
Hide file tree
Showing 9 changed files with 558 additions and 131 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ num_cpus = { version = "1.15.0", optional = true }
rand = { version = "0.8.5", optional = true }
reqwest = { version = "0.12.2", optional = true, default-features = false, features = [
"json",
"stream",
] }
rustls = { version = "0.23.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
Expand All @@ -41,7 +42,7 @@ native-tls = { version = "0.2.12", optional = true }
[features]
default = ["default-tls", "tokio", "ureq"]
# These features are only relevant when used with the `tokio` feature, but this might change in the future.
default-tls = []
default-tls = ["native-tls"]
native-tls = ["dep:reqwest", "reqwest/default", "dep:native-tls", "ureq/native-tls"]
rustls-tls = ["dep:rustls", "reqwest/rustls-tls"]
tokio = [
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ let _filename = repo.get("config.json").unwrap();

# SSL/TLS

This library uses its dependencies' default TLS implementations which are `rustls` for `ureq` (sync) and `native-tls` (openssl) for `tokio`.
This library uses tokio default TLS implementations which is `native-tls` (openssl) for `tokio`.

If you want control over the TLS backend you can remove the default features and only add the backend you are intending to use.

Expand Down
7 changes: 3 additions & 4 deletions examples/download.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#[cfg(not(feature = "ureq"))]
#[cfg(not(feature="tokio"))]
fn main() {
}
#[cfg(not(feature = "tokio"))]
fn main() {}

#[cfg(feature = "ureq")]
#[cfg(not(feature="tokio"))]
#[cfg(not(feature = "tokio"))]
fn main() {
let api = hf_hub::api::sync::Api::new().unwrap();

Expand Down
1 change: 1 addition & 0 deletions examples/iced/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/target
8 changes: 8 additions & 0 deletions examples/iced/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[package]
name = "iced_hf_hub"
version = "0.1.0"
edition = "2021"

[dependencies]
iced = { version = "0.13.1", features = ["tokio"] }
hf-hub = { path = "../../", default-features = false, features = ["tokio", "rustls-tls"] }
244 changes: 244 additions & 0 deletions examples/iced/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use hf_hub::api::tokio::{Api, ApiError};
use iced::futures::{SinkExt, Stream};
use iced::stream::try_channel;
use iced::task;
use iced::widget::{button, center, column, progress_bar, text, Column};

use iced::{Center, Element, Right, Task};

#[derive(Debug, Clone)]
pub enum Progress {
Downloading { current: usize, total: usize },
Finished,
}

#[derive(Debug, Clone)]
pub enum Error {
Api(String),
}

impl From<ApiError> for Error {
fn from(value: ApiError) -> Self {
Self::Api(value.to_string())
}
}

pub fn main() -> iced::Result {
iced::application("Download Progress - Iced", Example::update, Example::view).run()
}

#[derive(Debug)]
struct Example {
downloads: Vec<Download>,
last_id: usize,
}

#[derive(Clone)]
struct Prog {
output: iced::futures::channel::mpsc::Sender<Progress>,
total: usize,
}

impl hf_hub::api::tokio::Progress for Prog {
async fn update(&mut self, size: usize) {
let _ = self
.output
.send(Progress::Downloading {
current: size,
total: self.total,
})
.await;
}
async fn finish(&mut self) {
let _ = self.output.send(Progress::Finished).await;
}

async fn init(&mut self, size: usize, _filename: &str) {
println!("Initiating {size}");
let _ = self
.output
.send(Progress::Downloading {
current: 0,
total: size,
})
.await;
self.total = size;
}
}

pub fn download(
repo: String,
filename: impl AsRef<str>,
) -> impl Stream<Item = Result<Progress, Error>> {
try_channel(1, move |output| async move {
let prog = Prog { output, total: 0 };

let api = Api::new().unwrap().model(repo);
api.download_with_progress(filename.as_ref(), prog).await?;

Ok(())
})
}

#[derive(Debug, Clone)]
pub enum Message {
Add,
Download(usize),
DownloadProgressed(usize, Result<Progress, Error>),
}

impl Example {
fn new() -> Self {
Self {
downloads: vec![Download::new(0)],
last_id: 0,
}
}

fn update(&mut self, message: Message) -> Task<Message> {
match message {
Message::Add => {
self.last_id += 1;

self.downloads.push(Download::new(self.last_id));

Task::none()
}
Message::Download(index) => {
let Some(download) = self.downloads.get_mut(index) else {
return Task::none();
};

let task = download.start();

task.map(move |progress| Message::DownloadProgressed(index, progress))
}
Message::DownloadProgressed(id, progress) => {
if let Some(download) = self.downloads.iter_mut().find(|download| download.id == id)
{
download.progress(progress);
}

Task::none()
}
}
}

fn view(&self) -> Element<Message> {
let downloads = Column::with_children(self.downloads.iter().map(Download::view))
.push(
button("Add another download")
.on_press(Message::Add)
.padding(10),
)
.spacing(20)
.align_x(Right);

center(downloads).padding(20).into()
}
}

impl Default for Example {
fn default() -> Self {
Self::new()
}
}

#[derive(Debug)]
struct Download {
id: usize,
state: State,
}

#[derive(Debug)]
enum State {
Idle,
Downloading { progress: f32, _task: task::Handle },
Finished,
Errored,
}

impl Download {
pub fn new(id: usize) -> Self {
Download {
id,
state: State::Idle,
}
}

pub fn start(&mut self) -> Task<Result<Progress, Error>> {
match self.state {
State::Idle { .. } | State::Finished { .. } | State::Errored { .. } => {
let (task, handle) = Task::stream(download(
"mattshumer/Reflection-Llama-3.1-70B".to_string(),
"model-00001-of-00162.safetensors",
))
.abortable();

self.state = State::Downloading {
progress: 0.0,
_task: handle.abort_on_drop(),
};

task
}
State::Downloading { .. } => Task::none(),
}
}

pub fn progress(&mut self, new_progress: Result<Progress, Error>) {
if let State::Downloading { progress, .. } = &mut self.state {
match new_progress {
Ok(Progress::Downloading { current, total }) => {
println!("Status {progress} - {current}");
let new_progress = current as f32 / total as f32 * 100.0;
println!("New progress {current} {new_progress}");
*progress += new_progress;
}
Ok(Progress::Finished) => {
self.state = State::Finished;
}
Err(_error) => {
self.state = State::Errored;
}
}
}
}

pub fn view(&self) -> Element<Message> {
let current_progress = match &self.state {
State::Idle { .. } => 0.0,
State::Downloading { progress, .. } => *progress,
State::Finished { .. } => 100.0,
State::Errored { .. } => 0.0,
};

let progress_bar = progress_bar(0.0..=100.0, current_progress);

let control: Element<_> = match &self.state {
State::Idle => button("Start the download!")
.on_press(Message::Download(self.id))
.into(),
State::Finished => column!["Download finished!", button("Start again")]
.spacing(10)
.align_x(Center)
.into(),
State::Downloading { .. } => text!("Downloading... {current_progress:.2}%").into(),
State::Errored => column![
"Something went wrong :(",
button("Try again").on_press(Message::Download(self.id)),
]
.spacing(10)
.align_x(Center)
.into(),
};

Column::new()
.spacing(10)
.padding(10)
.align_x(Center)
.push(progress_bar)
.push(control)
.into()
}
}
47 changes: 47 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use indicatif::{ProgressBar, ProgressStyle};
use serde::Deserialize;

/// The asynchronous version of the API
Expand All @@ -8,6 +9,52 @@ pub mod tokio;
#[cfg(feature = "ureq")]
pub mod sync;

/// This trait is used by users of the lib
/// to implement custom behavior during file downloads
pub trait Progress {
/// At the start of the download
/// The size is the total size in bytes of the file.
fn init(&mut self, size: usize, filename: &str);
/// This function is called whenever `size` bytes have been
/// downloaded in the temporary file
fn update(&mut self, size: usize);
/// This is called at the end of the download
fn finish(&mut self);
}

impl Progress for () {
fn init(&mut self, _size: usize, _filename: &str) {}
fn update(&mut self, _size: usize) {}
fn finish(&mut self) {}
}

impl Progress for ProgressBar {
fn init(&mut self, size: usize, filename: &str) {
self.set_length(size as u64);
self.set_style(
ProgressStyle::with_template(
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
)
.unwrap(), // .progress_chars("━ "),
);
let maxlength = 30;
let message = if filename.len() > maxlength {
format!("..{}", &filename[filename.len() - maxlength..])
} else {
filename.to_string()
};
self.set_message(message);
}

fn update(&mut self, size: usize) {
self.inc(size as u64)
}

fn finish(&mut self) {
ProgressBar::finish(self);
}
}

/// Siblings are simplified file descriptions of remote files on the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct Siblings {
Expand Down
Loading

0 comments on commit a4e366a

Please sign in to comment.