Commit

Refactor SAC configuration and policy for improved action sampling and stability

- Updated SACConfig to replace the standard-deviation parameterization (std_parameterization, std_min, std_max) with log_std_min and log_std_max bounds for better control over action distributions (see the note after the configuration diff below).
- Modified SACPolicy so the actor now returns sampled actions and their log probabilities directly, streamlining action selection and the log-probability calculation (sketched below).
- Removed the now-unused TanhMultivariateNormalDiag class to simplify the codebase and improve maintainability.

These changes aim to enhance the robustness and performance of the SAC implementation during training and inference.
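For orientation, here is a minimal, self-contained sketch of the tanh-squashed Gaussian sampling that the updated actor performs inline, replacing the removed TanhMultivariateNormalDiag. The helper name squashed_gaussian_sample and the example tensor shapes are illustrative only; the clamping of the log standard deviation, the reparameterized sample, the 1e-6 epsilon, and the change-of-variables correction follow the diff below.

import torch

def squashed_gaussian_sample(means, log_std, log_std_min=-5.0, log_std_max=2.0):
    # Bound the log standard deviation, then exponentiate to get a positive std.
    log_std = torch.clamp(log_std, log_std_min, log_std_max)
    std = torch.exp(log_std)

    normal = torch.distributions.Normal(means, std)
    x_t = normal.rsample()              # reparameterization trick: mean + std * N(0, 1)
    log_probs = normal.log_prob(x_t)    # per-dimension log probabilities

    actions = torch.tanh(x_t)           # squash actions into [-1, 1]
    # Change-of-variables correction for the tanh squash, with a small epsilon
    # for numerical stability, then sum over the action dimension.
    log_probs -= torch.log(1 - actions.pow(2) + 1e-6)
    return actions, log_probs.sum(-1)

# Illustrative usage: a batch of 8 observations with 4-dimensional actions.
means = torch.zeros(8, 4)
log_std = torch.zeros(8, 4)
actions, log_probs = squashed_gaussian_sample(means, log_std)
print(actions.shape, log_probs.shape)  # torch.Size([8, 4]) torch.Size([8])

The correction term log(1 - tanh(x)^2 + eps) is the log-det Jacobian that the removed TanhMultivariateNormalDiag previously supplied through its TanhTransform.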
KeWang1017 authored and michel-aractingi committed Dec 29, 2024
1 parent 22fbc9e commit 5b4adc0
Showing 2 changed files with 44 additions and 218 deletions.
27 changes: 5 additions & 22 deletions lerobot/common/policies/sac/configuration_sac.py
@@ -53,30 +53,13 @@ class SACConfig:
critic_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
}
actor_network_kwargs = {
"hidden_dims": [256, 256],
"activate_final": True,
}
policy_kwargs = {
"tanh_squash_distribution": True,
"std_parameterization": "softplus",
"std_min": 0.005,
"std_max": 5.0,
}
)
output_shapes: dict[str, list[int]] = field(
default_factory=lambda: {
"action": [4],
policy_kwargs = {
"use_tanh_squash": True,
"log_std_min": -5,
"log_std_max": 2,
}
)

state_encoder_hidden_dim: int = 256
latent_dim: int = 256
network_hidden_dims: int = 256

# Normalization / Unnormalization
input_normalization_modes: dict[str, str] | None = None
output_normalization_modes: dict[str, str] = field(
default_factory=lambda: {"action": "min_max"},
)
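For reference, a small illustrative calculation (assuming the new defaults log_std_min = -5 and log_std_max = 2 from the policy_kwargs above) of the standard-deviation range they permit once exponentiated:

import math

log_std_min, log_std_max = -5, 2          # defaults from the new policy_kwargs
std_min = math.exp(log_std_min)           # ≈ 0.0067
std_max = math.exp(log_std_max)           # ≈ 7.389
print(f"std range: [{std_min:.4f}, {std_max:.4f}]")

This is roughly comparable in spread to the old std_min = 0.005 / std_max = 5.0 defaults, but the bounds are now applied in log space before exponentiation.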
235 changes: 39 additions & 196 deletions lerobot/common/policies/sac/modeling_sac.py
@@ -102,9 +102,7 @@ def reset(self):
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
distribution = self.actor(batch)
# Sample from the distribution and return just the actions
actions = distribution.mode() # or distribution.sample() for stochastic actions
actions, _ = self.actor(batch)
actions = self.unnormalize_outputs({"action": actions})["action"]
return actions

@@ -129,12 +127,11 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:

# reward bias from HIL-SERL code base
# add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch

# calculate critics loss
# 1- compute actions from policy
distribution = self.actor(observations)
action_preds = distribution.sample()
action_preds = torch.clamp(action_preds, -1, +1)
action_preds, log_probs = self.actor(next_observations)

# 2- compute q targets
q_targets = self.target_qs(next_observations, action_preds)
# subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
@@ -147,7 +144,7 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
min_q = q_targets.min(dim=0)

# compute td target
td_target = rewards + self.discount * min_q
td_target = rewards + self.config.discount * min_q #+ self.config.discount * self.temperature() * log_probs # add entropy term

# 3- compute predicted qs
q_preds = self.critic_ensemble(observations, actions)
@@ -178,18 +175,12 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]), # expand td_target to match q_preds shape
reduction="none"
).sum(0).mean()
# breakpoint()

# calculate actors loss
# 1- temperature
temperature = self.temperature()

# 2- get actions (batch_size, action_dim) and log probs (batch_size,)
distribution = self.actor(observations)
actions = distribution.rsample()
log_probs = distribution.log_prob(actions).sum(-1)
# breakpoint()
actions = torch.clamp(actions, -1, +1)
actions, log_probs = self.actor(observations)
# 3- get q-value predictions
with torch.no_grad():
q_preds = self.critic_ensemble(observations, actions, return_type="mean")
@@ -264,15 +255,13 @@ def __init__(
encoder: Optional[nn.Module],
network: nn.Module,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda",
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.init_final = init_final
self.activate_final = activate_final

# Find the last Linear layer's output dimension
for layer in reversed(network.net):
@@ -304,49 +293,29 @@ def forward(self, observations: torch.Tensor, actions: torch.Tensor, train: bool
value = self.output_layer(x)
return value.squeeze(-1)

def q_value_ensemble(
self, observations: torch.Tensor, actions: torch.Tensor, train: bool = False
) -> torch.Tensor:
observations = observations.to(self.device)
actions = actions.to(self.device)

if len(actions.shape) == 3: # [batch_size, num_actions, action_dim]
batch_size, num_actions = actions.shape[:2]
obs_expanded = observations.unsqueeze(1).expand(-1, num_actions, -1)
obs_flat = obs_expanded.reshape(-1, observations.shape[-1])
actions_flat = actions.reshape(-1, actions.shape[-1])
q_values = self(obs_flat, actions_flat, train)
return q_values.reshape(batch_size, num_actions)
else:
return self(observations, actions, train)


class Policy(nn.Module):
def __init__(
self,
encoder: Optional[nn.Module],
network: nn.Module,
action_dim: int,
std_parameterization: str = "exp",
std_min: float = 0.05,
std_max: float = 2.0,
tanh_squash_distribution: bool = False,
log_std_min: float = -5,
log_std_max: float = 2,
fixed_std: Optional[torch.Tensor] = None,
init_final: Optional[float] = None,
activate_final: bool = False,
device: str = "cuda",
use_tanh_squash: bool = False,
device: str = "cuda"
):
super().__init__()
self.device = torch.device(device)
self.encoder = encoder
self.network = network
self.action_dim = action_dim
self.std_parameterization = std_parameterization
self.std_min = std_min
self.std_max = std_max
self.tanh_squash_distribution = tanh_squash_distribution
self.log_std_min = log_std_min
self.log_std_max = log_std_max
self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
self.activate_final = activate_final
self.use_tanh_squash = use_tanh_squash

# Find the last Linear layer's output dimension
for layer in reversed(network.net):
@@ -364,27 +333,20 @@ def __init__(

# Standard deviation layer or parameter
if fixed_std is None:
if std_parameterization == "uniform":
self.log_stds = nn.Parameter(torch.zeros(action_dim, device=self.device))
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
self.std_layer = nn.Linear(out_features, action_dim)
if init_final is not None:
nn.init.uniform_(self.std_layer.weight, -init_final, init_final)
nn.init.uniform_(self.std_layer.bias, -init_final, init_final)
else:
orthogonal_init()(self.std_layer.weight)

orthogonal_init()(self.std_layer.weight)

self.to(self.device)

def forward(
self,
observations: torch.Tensor,
temperature: float = 1.0,
train: bool = False,
non_squash_distribution: bool = False,
) -> torch.distributions.Distribution:
self.train(train)

) -> Tuple[torch.Tensor, torch.Tensor]:

# Encode observations if encoder exists
if self.encoder is not None:
with torch.set_grad_enabled(train):
@@ -398,41 +360,24 @@ def forward(

# Compute standard deviations
if self.fixed_std is None:
if self.std_parameterization == "exp":
log_stds = self.std_layer(outputs)
# Clamp log_stds to prevent too large or small values
log_stds = torch.clamp(log_stds, math.log(self.std_min), math.log(self.std_max))
stds = torch.exp(log_stds)
elif self.std_parameterization == "softplus":
stds = torch.nn.functional.softplus(self.std_layer(outputs))
stds = torch.clamp(stds, self.std_min, self.std_max)
elif self.std_parameterization == "uniform":
log_stds = torch.clamp(self.log_stds, math.log(self.std_min), math.log(self.std_max))
stds = torch.exp(log_stds).expand_as(means)
else:
raise ValueError(f"Invalid std_parameterization: {self.std_parameterization}")
log_std = self.std_layer(outputs)
if self.use_tanh_squash:
log_std = torch.tanh(log_std)
log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)
else:
assert self.std_parameterization == "fixed"
stds = self.fixed_std.expand_as(means)

# Scale with temperature
temperature = torch.tensor(temperature, device=self.device)
stds = torch.clamp(stds, self.std_min, self.std_max) * torch.sqrt(temperature)

# Create distribution
if self.tanh_squash_distribution and not non_squash_distribution:
distribution = TanhMultivariateNormalDiag(
loc=means,
scale_diag=stds,
)
else:
distribution = torch.distributions.Normal(
loc=means,
scale=stds,
)

return distribution

# uses the tanh activation function to squash the action into the range [-1, 1]
normal = torch.distributions.Normal(means, stds)
x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1))
log_probs = normal.log_prob(x_t)
if self.use_tanh_squash:
actions = torch.tanh(x_t)
log_probs -= torch.log((1 - actions.pow(2)) + 1e-6)
log_probs = log_probs.sum(-1) # sum over action dim

return actions, log_probs

def get_features(self, observations: torch.Tensor) -> torch.Tensor:
"""Get encoded features from observations"""
observations = observations.to(self.device)
@@ -552,110 +497,8 @@ def forward(self, lhs: Optional[torch.Tensor] = None, rhs: Optional[torch.Tensor
return multiplier * diff


# The TanhMultivariateNormalDiag is a probability distribution that represents a transformed normal (Gaussian) distribution where:
# 1. The base distribution is a diagonal multivariate normal distribution
# 2. The samples from this normal distribution are transformed through a tanh function, which squashes the values to be between -1 and 1
# 3. Optionally, the values can be further transformed to fit within arbitrary bounds [low, high] using an affine transformation
# This type of distribution is commonly used in reinforcement learning, particularly for continuous action spaces
class TanhMultivariateNormalDiag(torch.distributions.TransformedDistribution):
DEFAULT_SAMPLE_SHAPE = torch.Size()

def __init__(
self,
loc: torch.Tensor,
scale_diag: torch.Tensor,
low: Optional[torch.Tensor] = None,
high: Optional[torch.Tensor] = None,
):
# Create base normal distribution
base_distribution = torch.distributions.Normal(loc=loc, scale=scale_diag)

# Create list of transforms
transforms = []

# Add tanh transform
transforms.append(torch.distributions.transforms.TanhTransform())

# Add rescaling transform if bounds are provided
if low is not None and high is not None:
transforms.append(
torch.distributions.transforms.AffineTransform(loc=(high + low) / 2, scale=(high - low) / 2)
)

# Initialize parent class
super().__init__(base_distribution=base_distribution, transforms=transforms)

# Store parameters
self.loc = loc
self.scale_diag = scale_diag
self.low = low
self.high = high

def mode(self) -> torch.Tensor:
"""Get the mode of the transformed distribution"""
# The mode of a normal distribution is its mean
mode = self.loc
# Apply transforms
for transform in self.transforms:
mode = transform(mode)

return mode

def rsample(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> torch.Tensor:
"""
Reparameterized sample from the distribution
"""
# Sample from the base distribution using rsample
x = self.base_dist.rsample(sample_shape)

# Apply transforms
for transform in self.transforms:
x = transform(x)

return x

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
"""
Compute log probability of a value
Includes the log det jacobian for the transforms
"""
# Initialize log prob
log_prob = torch.zeros_like(value)

# Inverse transforms to get back to normal distribution
q = value
for transform in reversed(self.transforms):
q_prev = transform.inv(q) # Get the pre-transform value
log_prob = log_prob - transform.log_abs_det_jacobian(q_prev, q) # Sum over action dimensions
q = q_prev

# Add base distribution log prob
log_prob = log_prob + self.base_dist.log_prob(q) # Sum over action dimensions

return log_prob

def sample_and_log_prob(self, sample_shape=DEFAULT_SAMPLE_SHAPE) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Sample from the distribution and compute log probability
"""
x = self.rsample(sample_shape)
log_prob = self.log_prob(x)
return x, log_prob

# def entropy(self) -> torch.Tensor:
# """
# Compute entropy of the distribution
# """
# # Start with base distribution entropy
# entropy = self.base_dist.entropy().sum(-1)

# # Add log det jacobian for each transform
# x = self.rsample()
# for transform in self.transforms:
# entropy = entropy + transform.log_abs_det_jacobian(x, transform(x))
# x = transform(x)

# return entropy
def orthogonal_init():
return lambda x: torch.nn.init.orthogonal_(x, gain=1.0)


def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList:
