Merge pull request #3 from Kittensx/Kittensx-patch-Simple-KES
Update simple_karras_exponential_scheduler.py
Kittensx authored Oct 27, 2024
2 parents 6ffb728 + da2e709 commit 4ea5403
Showing 1 changed file with 46 additions and 50 deletions.
96 changes: 46 additions & 50 deletions modules/simple_karras_exponential_scheduler.py
@@ -13,6 +13,26 @@
import logging
from datetime import datetime

def get_random_or_default(scheduler_config, key_prefix, default_value, global_randomize):
"""Helper function to either randomize a value based on conditions or return the default."""

# Determine if we should randomize based on global and individual flags
randomize_flag = global_randomize or scheduler_config.get(f'{key_prefix}_rand', False)

if randomize_flag:
# Use specified min/max values for randomization if they exist, else use default range
rand_min = scheduler_config.get(f'{key_prefix}_rand_min', default_value * 0.8)
rand_max = scheduler_config.get(f'{key_prefix}_rand_max', default_value * 1.2)
value = random.uniform(rand_min, rand_max)
custom_logger.info(f"Randomized {key_prefix}: {value}")
else:
# Use default value if no randomization is applied
value = default_value
custom_logger.info(f"Using default {key_prefix}: {value}")

return value


class CustomLogger:
    def __init__(self, log_name, print_to_console=False, debug_enabled=False):
        self.print_to_console = print_to_console  # prints to console
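
A minimal usage sketch of the new get_random_or_default helper. The key names, bounds, and defaults below are hypothetical rather than taken from the project's simple_kes_scheduler.yaml, and the calls assume the module's random import and custom_logger instance are in scope:

# Hypothetical scheduler_config for illustration only
scheduler_config = {
    'sigma_min_rand': True,         # per-key flag: randomize this value
    'sigma_min_rand_min': 0.005,    # explicit lower bound for the draw
    'sigma_min_rand_max': 0.02,     # explicit upper bound for the draw
    'sigma_max_rand': False,        # no randomization, keep the default
}

sigma_min = get_random_or_default(scheduler_config, 'sigma_min', 0.01, global_randomize=False)  # uniform draw from [0.005, 0.02]
sigma_max = get_random_or_default(scheduler_config, 'sigma_max', 50, global_randomize=False)    # returns the default, 50
# With global_randomize=True, sigma_max would instead be drawn from [50 * 0.8, 50 * 1.2].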
@@ -149,38 +169,8 @@ def start_config_watcher(config_manager, config_path):

# Start watching for config changes
observer = start_config_watcher(config_manager, config_path)
'''
def get_random_or_default(config, key_prefix, default_value):
"""Helper function to either randomize a value or return the default."""
randomize_flag = config['scheduler'].get(f'{key_prefix}_rand', False)
if randomize_flag:
rand_min = config['scheduler'].get(f'{key_prefix}_rand_min', default_value * 0.8)
rand_max = config['scheduler'].get(f'{key_prefix}_rand_max', default_value * 1.2)
value = random.uniform(rand_min, rand_max)
custom_logger.info(f"Randomized {key_prefix}: {value}" )
else:
value = default_value
custom_logger.info(f"Using default {key_prefix}: {value}")
return value
'''
def get_random_or_default(config, key_prefix, default_value, global_randomize):
"""Helper function to either randomize a value based on conditions or return the default."""
# Check if global randomize is on or the individual flag is on
randomize_flag = global_randomize or config['scheduler'].get(f'{key_prefix}_rand', False)

if randomize_flag:
# Use specified min/max for randomization if the individual flag is set or global randomize is on
rand_min = config['scheduler'].get(f'{key_prefix}_rand_min', default_value * 0.8)
rand_max = config['scheduler'].get(f'{key_prefix}_rand_max', default_value * 1.2)
value = random.uniform(rand_min, rand_max)
custom_logger.info(f"Randomized {key_prefix}: {value}")
else:
value = default_value
custom_logger.info(f"Using default {key_prefix}: {value}")

return value


def simple_karras_exponential_scheduler(
    n, device, sigma_min=0.01, sigma_max=50, start_blend=0.1, end_blend=0.5,
    sharpness=0.95, early_stopping_threshold=0.01, update_interval=10, initial_step_size=0.9,
@@ -209,6 +199,15 @@ def simple_karras_exponential_scheduler(
    Returns:
        torch.Tensor: A tensor of blended sigma values.
    """
    config_path = os.path.join(os.path.dirname(__file__), 'simple_kes_scheduler.yaml')
    config = config_manager.load_config()
    scheduler_config = config.get('scheduler', {})
    if not scheduler_config:
        raise ValueError("Scheduler configuration is missing from the config file.")

    # Global randomization flag
    global_randomize = scheduler_config.get('randomize', False)

    #debug_log("Entered simple_karras_exponential_scheduler function")
    default_config = {
        "debug": False,
@@ -272,30 +271,27 @@
"noise_scale_factor_rand_max": 0.95,
}
custom_logger.info(f"Default Config create {default_config}")
for key, value in default_config.items():
custom_logger.info(f"Default Config - {key}: {value}")

#config = config_manager.load_config()
config = config_manager.load_config().get('scheduler', {})
global_randomize = config.get('randomize', randomize)

custom_logger.info(f"Config loaded from yaml {config}")
if not config:
raise ValueError("Scheduler configuration is missing from the config file.")

# Log loaded YAML configuration
custom_logger.info(f"Configuration loaded from YAML: {config}")

for key, value in config.items():
custom_logger.info(f"Config - {key}: {value}")

# Check if the scheduler config is available in the YAML file
scheduler_config = config.get('scheduler', {})
if not scheduler_config:
raise ValueError("Scheduler configuration is missing from the config file.")

for key, value in scheduler_config.items():
custom_logger.info(f"Scheduler Config before update - {key}: {value}")
for key, value in scheduler_config.items():
if key in default_config:
default_config[key] = value
default_config[key] = value # Override default with YAML value
custom_logger.info(f"Overriding default config: {key} = {value}")
else:
debug.log(f"Ignoring unknown config option: {key}")
custom_logger.info(f"Ignoring unknown config option: {key}")

custom_logger.info(f"Final configuration after merging with YAML: {default_config}")

global_randomize = default_config.get('randomize', False)
custom_logger.info(f"Global randomization flag set to: {global_randomize}")

custom_logger.info(f"Config loaded from yaml {config}")

# Now using default_config, updated with valid YAML values
custom_logger.info(f"Final Config after overriding: {default_config}")

