From 45842865ff74a632d23eafd3c91b9717f8cf2cef Mon Sep 17 00:00:00 2001 From: leejet Date: Sun, 3 Sep 2023 20:08:22 +0800 Subject: [PATCH] fix: seed should be 64 bit --- examples/main.cpp | 6 +++--- rng.h | 4 ++-- rng_philox.h | 2 +- stable-diffusion.cpp | 4 ++-- stable-diffusion.h | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/main.cpp b/examples/main.cpp index 3a1817c5..cfb73c1f 100644 --- a/examples/main.cpp +++ b/examples/main.cpp @@ -87,7 +87,7 @@ struct Option { int sample_steps = 20; float strength = 0.75f; RNGType rng_type = STD_DEFAULT_RNG; - int seed = 42; + int64_t seed = 42; bool verbose = false; void print() { @@ -106,7 +106,7 @@ struct Option { printf(" sample_steps: %d\n", sample_steps); printf(" strength: %.2f\n", strength); printf(" rng: %s\n", rng_type_to_str[rng_type]); - printf(" seed: %d\n", seed); + printf(" seed: %ld\n", seed); } }; @@ -233,7 +233,7 @@ void parse_args(int argc, const char* argv[], Option* opt) { invalid_arg = true; break; } - opt->seed = std::stoi(argv[i]); + opt->seed = std::stoll(argv[i]); } else if (arg == "-h" || arg == "--help") { print_usage(argc, argv); exit(0); diff --git a/rng.h b/rng.h index a3cb974a..e8942605 100644 --- a/rng.h +++ b/rng.h @@ -6,7 +6,7 @@ class RNG { public: - virtual void manual_seed(uint32_t seed) = 0; + virtual void manual_seed(uint64_t seed) = 0; virtual std::vector randn(uint32_t n) = 0; }; @@ -15,7 +15,7 @@ class STDDefaultRNG : public RNG { std::default_random_engine generator; public: - void manual_seed(uint32_t seed) { + void manual_seed(uint64_t seed) { generator.seed(seed); } diff --git a/rng_philox.h b/rng_philox.h index a159c9aa..c9b70fc2 100644 --- a/rng_philox.h +++ b/rng_philox.h @@ -93,7 +93,7 @@ class PhiloxRNG : public RNG { this->offset = 0; } - void manual_seed(uint32_t seed) { + void manual_seed(uint64_t seed) { this->seed = seed; this->offset = 0; } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 3eb09723..0063f1fe 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -3823,7 +3823,7 @@ std::vector StableDiffusion::txt2img(const std::string& prompt, int height, SampleMethod sample_method, int sample_steps, - int seed) { + int64_t seed) { std::vector result; struct ggml_init_params params; params.mem_size = static_cast(10 * 1024) * 1024; // 10M @@ -3911,7 +3911,7 @@ std::vector StableDiffusion::img2img(const std::vector& init_i SampleMethod sample_method, int sample_steps, float strength, - int seed) { + int64_t seed) { std::vector result; if (init_img_vec.size() != width * height * 3) { return result; diff --git a/stable-diffusion.h b/stable-diffusion.h index 11c892fa..41122653 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -40,7 +40,7 @@ class StableDiffusion { int height, SampleMethod sample_method, int sample_steps, - int seed); + int64_t seed); std::vector img2img( const std::vector& init_img, const std::string& prompt, @@ -51,7 +51,7 @@ class StableDiffusion { SampleMethod sample_method, int sample_steps, float strength, - int seed); + int64_t seed); }; void set_sd_log_level(SDLogLevel level);