Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom progressbar. Using a trait without fat pointer. #80

Merged
merged 4 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ num_cpus = { version = "1.15.0", optional = true }
rand = { version = "0.8.5", optional = true }
reqwest = { version = "0.12.2", optional = true, default-features = false, features = [
"json",
"stream",
] }
rustls = { version = "0.23.4", optional = true }
serde = { version = "1", features = ["derive"], optional = true }
Expand All @@ -41,7 +42,7 @@ native-tls = { version = "0.2.12", optional = true }
[features]
default = ["default-tls", "tokio", "ureq"]
# These features are only relevant when used with the `tokio` feature, but this might change in the future.
default-tls = []
default-tls = ["native-tls"]
native-tls = ["dep:reqwest", "reqwest/default", "dep:native-tls", "ureq/native-tls"]
rustls-tls = ["dep:rustls", "reqwest/rustls-tls"]
tokio = [
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ let _filename = repo.get("config.json").unwrap();

# SSL/TLS

This library uses its dependencies' default TLS implementations which are `rustls` for `ureq` (sync) and `native-tls` (openssl) for `tokio`.
This library uses tokio default TLS implementations which is `native-tls` (openssl) for `tokio`.

If you want control over the TLS backend you can remove the default features and only add the backend you are intending to use.

Expand Down
7 changes: 3 additions & 4 deletions examples/download.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
#[cfg(not(feature = "ureq"))]
#[cfg(not(feature="tokio"))]
fn main() {
}
#[cfg(not(feature = "tokio"))]
fn main() {}

#[cfg(feature = "ureq")]
#[cfg(not(feature="tokio"))]
#[cfg(not(feature = "tokio"))]
fn main() {
let api = hf_hub::api::sync::Api::new().unwrap();

Expand Down
1 change: 1 addition & 0 deletions examples/iced/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/target
8 changes: 8 additions & 0 deletions examples/iced/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[package]
name = "iced_hf_hub"
version = "0.1.0"
edition = "2021"

[dependencies]
iced = { version = "0.13.1", features = ["tokio"] }
hf-hub = { path = "../../", default-features = false, features = ["tokio", "rustls-tls"] }
244 changes: 244 additions & 0 deletions examples/iced/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use hf_hub::api::tokio::{Api, ApiError};
use iced::futures::{SinkExt, Stream};
use iced::stream::try_channel;
use iced::task;
use iced::widget::{button, center, column, progress_bar, text, Column};

use iced::{Center, Element, Right, Task};

#[derive(Debug, Clone)]
pub enum Progress {
Downloading { current: usize, total: usize },
Finished,
}

#[derive(Debug, Clone)]
pub enum Error {
Api(String),
}

impl From<ApiError> for Error {
fn from(value: ApiError) -> Self {
Self::Api(value.to_string())
}
}

pub fn main() -> iced::Result {
iced::application("Download Progress - Iced", Example::update, Example::view).run()
}

#[derive(Debug)]
struct Example {
downloads: Vec<Download>,
last_id: usize,
}

#[derive(Clone)]
struct Prog {
output: iced::futures::channel::mpsc::Sender<Progress>,
total: usize,
}

impl hf_hub::api::tokio::Progress for Prog {
async fn update(&mut self, size: usize) {
let _ = self
.output
.send(Progress::Downloading {
current: size,
total: self.total,
})
.await;
}
async fn finish(&mut self) {
let _ = self.output.send(Progress::Finished).await;
}

async fn init(&mut self, size: usize, _filename: &str) {
println!("Initiating {size}");
let _ = self
.output
.send(Progress::Downloading {
current: 0,
total: size,
})
.await;
self.total = size;
}
}

pub fn download(
repo: String,
filename: impl AsRef<str>,
) -> impl Stream<Item = Result<Progress, Error>> {
try_channel(1, move |output| async move {
let prog = Prog { output, total: 0 };

let api = Api::new().unwrap().model(repo);
api.download_with_progress(filename.as_ref(), prog).await?;

Ok(())
})
}

#[derive(Debug, Clone)]
pub enum Message {
Add,
Download(usize),
DownloadProgressed(usize, Result<Progress, Error>),
}

impl Example {
fn new() -> Self {
Self {
downloads: vec![Download::new(0)],
last_id: 0,
}
}

fn update(&mut self, message: Message) -> Task<Message> {
match message {
Message::Add => {
self.last_id += 1;

self.downloads.push(Download::new(self.last_id));

Task::none()
}
Message::Download(index) => {
let Some(download) = self.downloads.get_mut(index) else {
return Task::none();
};

let task = download.start();

task.map(move |progress| Message::DownloadProgressed(index, progress))
}
Message::DownloadProgressed(id, progress) => {
if let Some(download) = self.downloads.iter_mut().find(|download| download.id == id)
{
download.progress(progress);
}

Task::none()
}
}
}

fn view(&self) -> Element<Message> {
let downloads = Column::with_children(self.downloads.iter().map(Download::view))
.push(
button("Add another download")
.on_press(Message::Add)
.padding(10),
)
.spacing(20)
.align_x(Right);

center(downloads).padding(20).into()
}
}

impl Default for Example {
fn default() -> Self {
Self::new()
}
}

#[derive(Debug)]
struct Download {
id: usize,
state: State,
}

#[derive(Debug)]
enum State {
Idle,
Downloading { progress: f32, _task: task::Handle },
Finished,
Errored,
}

impl Download {
pub fn new(id: usize) -> Self {
Download {
id,
state: State::Idle,
}
}

pub fn start(&mut self) -> Task<Result<Progress, Error>> {
match self.state {
State::Idle { .. } | State::Finished { .. } | State::Errored { .. } => {
let (task, handle) = Task::stream(download(
"mattshumer/Reflection-Llama-3.1-70B".to_string(),
"model-00001-of-00162.safetensors",
))
.abortable();

self.state = State::Downloading {
progress: 0.0,
_task: handle.abort_on_drop(),
};

task
}
State::Downloading { .. } => Task::none(),
}
}

pub fn progress(&mut self, new_progress: Result<Progress, Error>) {
if let State::Downloading { progress, .. } = &mut self.state {
match new_progress {
Ok(Progress::Downloading { current, total }) => {
println!("Status {progress} - {current}");
let new_progress = current as f32 / total as f32 * 100.0;
println!("New progress {current} {new_progress}");
*progress += new_progress;
}
Ok(Progress::Finished) => {
self.state = State::Finished;
}
Err(_error) => {
self.state = State::Errored;
}
}
}
}

pub fn view(&self) -> Element<Message> {
let current_progress = match &self.state {
State::Idle { .. } => 0.0,
State::Downloading { progress, .. } => *progress,
State::Finished { .. } => 100.0,
State::Errored { .. } => 0.0,
};

let progress_bar = progress_bar(0.0..=100.0, current_progress);

let control: Element<_> = match &self.state {
State::Idle => button("Start the download!")
.on_press(Message::Download(self.id))
.into(),
State::Finished => column!["Download finished!", button("Start again")]
.spacing(10)
.align_x(Center)
.into(),
State::Downloading { .. } => text!("Downloading... {current_progress:.2}%").into(),
State::Errored => column![
"Something went wrong :(",
button("Try again").on_press(Message::Download(self.id)),
]
.spacing(10)
.align_x(Center)
.into(),
};

Column::new()
.spacing(10)
.padding(10)
.align_x(Center)
.push(progress_bar)
.push(control)
.into()
}
}
47 changes: 47 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use indicatif::{ProgressBar, ProgressStyle};
use serde::Deserialize;

/// The asynchronous version of the API
Expand All @@ -8,6 +9,52 @@ pub mod tokio;
#[cfg(feature = "ureq")]
pub mod sync;

/// This trait is used by users of the lib
/// to implement custom behavior during file downloads
pub trait Progress {
/// At the start of the download
/// The size is the total size in bytes of the file.
fn init(&mut self, size: usize, filename: &str);
/// This function is called whenever `size` bytes have been
/// downloaded in the temporary file
fn update(&mut self, size: usize);
/// This is called at the end of the download
fn finish(&mut self);
}

impl Progress for () {
fn init(&mut self, _size: usize, _filename: &str) {}
fn update(&mut self, _size: usize) {}
fn finish(&mut self) {}
}

impl Progress for ProgressBar {
fn init(&mut self, size: usize, filename: &str) {
self.set_length(size as u64);
self.set_style(
ProgressStyle::with_template(
"{msg} [{elapsed_precise}] [{wide_bar}] {bytes}/{total_bytes} {bytes_per_sec} ({eta})",
)
.unwrap(), // .progress_chars("━ "),
);
let maxlength = 30;
let message = if filename.len() > maxlength {
format!("..{}", &filename[filename.len() - maxlength..])
} else {
filename.to_string()
};
self.set_message(message);
}

fn update(&mut self, size: usize) {
self.inc(size as u64)
}

fn finish(&mut self) {
ProgressBar::finish(self);
}
}

/// Siblings are simplified file descriptions of remote files on the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct Siblings {
Expand Down
Loading
Loading