Module sac-pendulum.agent

Classes

class Agent (n_obs: int, n_actions: int, max_action: float, checkpoint_dir: pathlib.Path | str, buffer_size: int = 1000000, actor_lr: float = 0.0003, critic_lr: float = 0.0003, alpha: float = 0.2, gamma: float = 0.99, tau: float = 0.005, hidden_dims: int = 256, reward_scale: int = 2, batch_size: int = 256)
class Agent:
    def __init__(
        self,
        n_obs: int,
        n_actions: int,
        max_action: float,
        checkpoint_dir: Path | str,
        buffer_size: int = 1_000_000,
        actor_lr: float = 0.0003,
        critic_lr: float = 0.0003,
        alpha: float = 0.2,
        gamma: float = 0.99,
        tau: float = 0.005,
        hidden_dims: int = 256,
        reward_scale: int = 2,
        batch_size: int = 256,
    ) -> None:
        self.checkpoint_dir = checkpoint_dir
        self.batch_size = batch_size
        self.n_obs = n_obs
        self.n_actions = n_actions
        self.max_action = max_action
        self.buffer_size = buffer_size
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        self.alpha = alpha
        self.gamma = gamma
        self.tau = tau
        self.hidden_dims = hidden_dims
        self.reward_scale = reward_scale

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Policy network, twin Q-networks, and a state-value network with its target copy.
        self.actor = ActorNetwork(n_obs, n_actions, max_action, hidden_dims).to(
            self.device
        )
        self.critic_1 = CriticNetwork(n_obs, n_actions, hidden_dims).to(self.device)
        self.critic_2 = CriticNetwork(n_obs, n_actions, hidden_dims).to(self.device)
        self.value = ValueNetwork(n_obs, hidden_dims).to(self.device)
        self.value_target = ValueNetwork(n_obs, hidden_dims).to(self.device)

        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=critic_lr)
        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=critic_lr)
        self.value_optimizer = optim.Adam(self.value.parameters(), lr=critic_lr)

        self.buffer = ReplayBuffer(buffer_size)
        # Hard-copy the value network's weights into the target network (tau = 1).
        self.update_network_parameters(1)

    def select_action(self, state):
        # Sample a stochastic action from the current policy for a single observation.
        state = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(self.device)
        actions, _ = self.actor.sample(state)
        return actions.cpu().detach().numpy()[0]

    def remember(self, state, action, reward, next_state, done):
        # Store a single transition in the replay buffer.
        self.buffer.add(Frame(state, action, reward, next_state, done))

    def update_network_parameters(self, tau=None):
        # Polyak (soft) update of the target value network:
        # theta_target <- tau * theta + (1 - tau) * theta_target.
        if tau is None:
            tau = self.tau

        for target_param, param in zip(
            self.value_target.parameters(), self.value.parameters()
        ):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def learn(self):
        # Wait until the buffer holds at least one full batch of transitions.
        if len(self.buffer) < self.batch_size:
            return

        # Sample a batch, unpack it per field, and convert to tensors on the device.
        states, actions, rewards, next_states, dones = zip(
            *self.buffer.sample(self.batch_size)
        )

        states = torch.tensor(states, dtype=torch.float).to(self.device)
        actions = torch.tensor(actions, dtype=torch.float).to(self.device)
        rewards = torch.tensor(rewards, dtype=torch.float).to(self.device).unsqueeze(1)
        next_states = torch.tensor(next_states, dtype=torch.float).to(self.device)
        dones = torch.tensor(dones, dtype=torch.float).to(self.device).unsqueeze(1)

        # Current value estimates and bootstrapped next-state values (zeroed at terminal states).
        value = self.value(states).view(-1)
        next_value = self.value_target(next_states).view(-1)
        next_value = next_value * (1 - dones).view(-1)

        # Evaluate freshly sampled actions with the twin critics and take the minimum.
        actions_new, log_probs = self.actor.sample(states)
        log_probs = log_probs.view(-1)
        q1_new_policy = self.critic_1(states, actions_new).view(-1)
        q2_new_policy = self.critic_2(states, actions_new).view(-1)
        critic_value = torch.min(q1_new_policy, q2_new_policy)

        # Value network update: regress V(s) toward min_i Q_i(s, a_new) - log pi(a_new | s).
        self.value_optimizer.zero_grad()
        value_target = critic_value - log_probs
        value_loss = F.mse_loss(value, value_target)
        value_loss.backward()
        self.value_optimizer.step()

        # Actor update: re-sample actions and minimize alpha * log pi - min_i Q_i.
        actions_new, log_probs = self.actor.sample(states)
        log_probs = log_probs.view(-1)
        q1_new_policy = self.critic_1(states, actions_new).view(-1)
        q2_new_policy = self.critic_2(states, actions_new).view(-1)
        critic_value = torch.min(q1_new_policy, q2_new_policy)

        actor_loss = (self.alpha * log_probs - critic_value).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        self.actor_optimizer.step()

        # Critic update: regress both Q-networks toward the scaled reward plus the
        # discounted target value of the next state.
        self.critic_1_optimizer.zero_grad()
        self.critic_2_optimizer.zero_grad()
        q_hat = self.reward_scale * rewards + self.gamma * next_value.unsqueeze(1)
        q1_old_policy = self.critic_1(states, actions).view(-1)
        q2_old_policy = self.critic_2(states, actions).view(-1)
        critic_1_loss = F.mse_loss(q1_old_policy, q_hat.view(-1))
        critic_2_loss = F.mse_loss(q2_old_policy, q_hat.view(-1))

        critic_loss = critic_1_loss + critic_2_loss
        critic_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.step()

        # Soft-update the target value network.
        self.update_network_parameters()

    def save_models(self):
        print("saving models\n===================================")
        torch.save(self.actor.state_dict(), f"{self.checkpoint_dir}/actor.pth")
        torch.save(self.value.state_dict(), f"{self.checkpoint_dir}/value.pth")
        torch.save(
            self.value_target.state_dict(), f"{self.checkpoint_dir}/value_target.pth"
        )
        torch.save(self.critic_1.state_dict(), f"{self.checkpoint_dir}/critic_1.pth")
        torch.save(self.critic_2.state_dict(), f"{self.checkpoint_dir}/critic_2.pth")

        torch.save(
            self.actor_optimizer.state_dict(),
            f"{self.checkpoint_dir}/actor_optimizer.pth",
        )
        torch.save(
            self.value_optimizer.state_dict(),
            f"{self.checkpoint_dir}/value_optimizer.pth",
        )
        torch.save(
            self.critic_1_optimizer.state_dict(),
            f"{self.checkpoint_dir}/critic_1_optimizer.pth",
        )
        torch.save(
            self.critic_2_optimizer.state_dict(),
            f"{self.checkpoint_dir}/critic_2_optimizer.pth",
        )

    def load_models(self):
        print("loading models\n===================================")
        self.actor.load_state_dict(
            torch.load(f"{self.checkpoint_dir}/actor.pth", weights_only=True)
        )
        self.value.load_state_dict(
            torch.load(f"{self.checkpoint_dir}/value.pth", weights_only=True)
        )
        self.value_target.load_state_dict(
            torch.load(f"{self.checkpoint_dir}/value_target.pth", weights_only=True),
        )
        self.critic_1.load_state_dict(
            torch.load(f"{self.checkpoint_dir}/critic_1.pth", weights_only=True)
        )
        self.critic_2.load_state_dict(
            torch.load(f"{self.checkpoint_dir}/critic_2.pth", weights_only=True)
        )

        self.actor_optimizer.load_state_dict(
            torch.load(f"{self.checkpoint_dir}/actor_optimizer.pth", weights_only=True)
        )
        self.value_optimizer.load_state_dict(
            torch.load(f"{self.checkpoint_dir}/value_optimizer.pth", weights_only=True)
        )
        self.critic_1_optimizer.load_state_dict(
            torch.load(
                f"{self.checkpoint_dir}/critic_1_optimizer.pth", weights_only=True
            )
        )
        self.critic_2_optimizer.load_state_dict(
            torch.load(
                f"{self.checkpoint_dir}/critic_2_optimizer.pth", weights_only=True
            )
        )
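
Agent implements the original (state-value) formulation of Soft Actor-Critic with twin Q-critics and a Polyak-averaged target value network. As a reading aid, the objectives minimized in learn() can be written as follows; the notation is mine and not part of the source (\tilde{a} \sim \pi(\cdot \mid s) is a freshly sampled action, \lambda is reward_scale, \bar{V} is the target value network, and d is the done flag):

J_V     = \mathbb{E}\big[\big(V(s) - (\min_{i=1,2} Q_i(s,\tilde{a}) - \log\pi(\tilde{a}\mid s))\big)^2\big]
J_\pi   = \mathbb{E}\big[\alpha \log\pi(\tilde{a}\mid s) - \min_{i=1,2} Q_i(s,\tilde{a})\big]
J_{Q_i} = \mathbb{E}\big[\big(Q_i(s,a) - (\lambda r + \gamma (1-d)\,\bar{V}(s'))\big)^2\big]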

Methods

def learn(self)
Sample a batch from the replay buffer and update the value, actor, and critic networks, then soft-update the target value network. Does nothing until the buffer holds batch_size transitions.

def load_models(self)
Load network and optimizer state dicts from checkpoint_dir.

def remember(self, state, action, reward, next_state, done)
Store a single transition in the replay buffer.

def save_models(self)
Save network and optimizer state dicts to checkpoint_dir.

def select_action(self, state)
Sample a stochastic action from the current policy for a single observation.

def update_network_parameters(self, tau=None)
Soft-update (Polyak-average) the target value network; tau=1 performs a hard copy.
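
For orientation, a minimal training-loop sketch follows. It is an illustration under stated assumptions, not part of the module: it assumes agent.py is importable as agent, a Gymnasium Pendulum-v1 environment, and an existing checkpoints/ directory; adjust the import and paths to your layout.

from pathlib import Path

import gymnasium as gym

from agent import Agent  # assumed import path; adjust to your layout

env = gym.make("Pendulum-v1")
sac = Agent(
    n_obs=env.observation_space.shape[0],
    n_actions=env.action_space.shape[0],
    max_action=float(env.action_space.high[0]),
    checkpoint_dir=Path("checkpoints"),  # must exist before save_models()
)

for episode in range(100):
    state, _ = env.reset()
    done = False
    episode_reward = 0.0
    while not done:
        action = sac.select_action(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        sac.remember(state, action, reward, next_state, done)
        sac.learn()  # no-op until the buffer holds batch_size transitions
        state = next_state
        episode_reward += float(reward)
    print(f"episode {episode}: reward {episode_reward:.1f}")

sac.save_models()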