Source code for symjax.rl.ddpg

# main class

import symjax
import numpy as np
from symjax import nn
import symjax.tensor as T
from . import agents

# https://gist.github.com/heerad/1983d50c6657a55298b67e69a2ceeb44


# class DDPG(Agent):


class DDPG(agents.Agent):
    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        gamma=0.99,
        tau=0.01,
    ):
        self.gamma = gamma
        self.tau = tau
        self.lr = lr
        self.batch_size = batch_size

        states = T.Placeholder((batch_size,) + state_shape, "float32")
        actions = T.Placeholder((batch_size,) + actions_shape, "float32")

        self.critic = critic(states, actions)
        self.target_critic = critic(states, actions)

        # create critic loss
        targets = T.Placeholder(self.critic.q_values.shape, "float32")
        critic_loss = ((self.critic.q_values - targets) ** 2).mean()

        # create optimizer
        with symjax.Scope("critic_optimizer"):
            nn.optimizers.Adam(critic_loss, lr, params=self.critic.params(True))

        # create the update function
        self._train_critic = symjax.function(
            states,
            actions,
            targets,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="*/critic_optimizer"),
        )

        # now create utility function to get the gradients
        grad = symjax.gradients(self.critic.q_values.sum(), actions)
        self._get_critic_gradients = symjax.function(states, actions, outputs=grad)

        # create actor loss
        self.actor = actor(states)
        self.target_actor = actor(states)
        gradients = T.Placeholder(actions.shape, "float32")
        actor_loss = -(self.actor.actions * gradients).mean()

        # create optimizer
        with symjax.Scope("actor_optimizer"):
            nn.optimizers.Adam(actor_loss, lr, params=self.actor.params(True))

        # create the update function
        self._train_actor = symjax.function(
            states,
            gradients,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/actor_optimizer"),
        )

        # initialize both networks as the same
        self.update_target(1)

    def train(self, buffer, *args, **kwargs):
        s, a, r, s2, t = buffer.sample(self.batch_size)

        # Calculate the target for the critic
        a2 = self.target_actor.get_actions(s2)[0]
        q_values = self.target_critic.get_q_values(s2, a2)
        targets = r + (1 - t.astype("float32")) * self.gamma * q_values.squeeze()

        c_loss = self._train_critic(s, a, targets)

        # if not self.continuous:
        #     a = (np.arange(self.num_actions) == a[:, None]).astype("float32")

        actions = self.actor.get_actions(s)[0]
        gradients = self._get_critic_gradients(s, actions)

        a_loss = self._train_actor(s, gradients)

        self.update_target()
        return a_loss, c_loss
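Both networks are synchronized through ``update_target``, which is inherited from ``agents.Agent`` and not shown on this page. A minimal numpy sketch of the Polyak averaging it is assumed to perform (``polyak_update`` and the dummy weight lists below are illustrative, not symjax API):

# Illustrative sketch only (not part of symjax): Polyak averaging of
# target-network weights, as assumed for `update_target(tau)`.
import numpy as np


def polyak_update(online_params, target_params, tau):
    """Move each target weight a fraction `tau` toward its online counterpart."""
    return [tau * w + (1 - tau) * w_t for w, w_t in zip(online_params, target_params)]


# usage with dummy weight arrays; tau=1 copies the online weights exactly,
# matching the `self.update_target(1)` initialization above
online = [np.ones((4, 4)), np.zeros(4)]
target = [np.zeros((4, 4)), np.ones(4)]
target = polyak_update(online, target, tau=0.01)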
class REINFORCE(agents.Agent):
    """
    Policy gradient REINFORCE, also called reward-to-go policy gradient.

    The vanilla policy gradient uses the total reward of each episode as a
    weight. In this implementation, the discounted rewards-to-go are used
    instead. Setting ``gamma`` to 1 recovers the (undiscounted) reward-to-go
    policy gradient.

    https://medium.com/@thechrisyoon/deriving-policy-gradients-and-implementing-reinforce-f887949bd63
    """

    def __init__(
        self,
        state_shape,
        actions_shape,
        n_episodes,
        episode_length,
        actor,
        lr=1e-3,
        gamma=0.99,
    ):
        self.actor = actor
        self.gamma = gamma
        self.lr = lr
        self.episode_length = episode_length
        self.n_episodes = n_episodes
        self.batch_size = episode_length * n_episodes

        states = T.Placeholder((self.batch_size,) + state_shape, "float32")
        actions = T.Placeholder((self.batch_size,) + actions_shape, "float32")
        discounted_rewards = T.Placeholder((self.batch_size,), "float32")

        self.actor = actor(states, distribution="gaussian")
        logprobs = self.actor.actions.log_prob(actions)
        actor_loss = -(logprobs * discounted_rewards).sum() / n_episodes

        with symjax.Scope("REINFORCE_optimizer"):
            nn.optimizers.Adam(
                actor_loss,
                lr,
                params=self.actor.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            discounted_rewards,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/REINFORCE_optimizer"),
        )

    def train(self, buffer, *args, **kwargs):
        assert buffer.n_episodes == self.n_episodes

        indices = list(range(self.episode_length * self.n_episodes))
        states, actions, disc_rewards = buffer.sample(
            indices,
            ["state", "action", "reward-to-go"],
        )

        # standardize the discounted rewards-to-go
        disc_rewards -= disc_rewards.mean()
        disc_rewards /= disc_rewards.std()

        loss = self._train(states, actions, disc_rewards)
        buffer.reset()
        return loss
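The ``reward-to-go`` entries sampled from the buffer are the discounted sums R_t = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ... for each time step. A minimal numpy sketch of that computation for a single episode (illustrative only, not the buffer's actual implementation):

# Illustrative sketch only: discounted rewards-to-go for one episode,
# computed backwards so each step reuses the running discounted sum.
import numpy as np


def rewards_to_go(rewards, gamma=0.99):
    rtg = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg


print(rewards_to_go([1.0, 0.0, 2.0], gamma=0.99))  # [2.9602, 1.98, 2.0]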
class ActorCritic(agents.Agent):
    """
    Q or V actor-critic, depending on the given critic
    (with GAE-Lambda for advantage estimation).

    https://www.freecodecamp.org/news/an-intro-to-advantage-actor-critic-methods-lets-play-sonic-the-hedgehog-86d6240171d/
    """

    def __init__(
        self,
        state_shape,
        actions_shape,
        n_episodes,
        episode_length,
        actor,
        critic,
        lr=1e-3,
        gamma=0.99,
        train_v_iters=10,
    ):
        self.actor = actor
        self.critic = critic
        self.gamma = gamma
        self.lr = lr
        self.episode_length = episode_length
        self.n_episodes = n_episodes
        self.batch_size = episode_length * n_episodes
        self.train_v_iters = train_v_iters

        states = T.Placeholder((self.batch_size,) + state_shape, "float32")
        actions = T.Placeholder((self.batch_size,) + actions_shape, "float32")
        discounted_rewards = T.Placeholder((self.batch_size,), "float32")
        advantages = T.Placeholder((self.batch_size,), "float32")

        self.actor = actor(states, distribution="gaussian")
        self.critic = critic(states)

        logprobs = self.actor.actions.log_prob(actions)
        actor_loss = -(logprobs * advantages).sum() / n_episodes
        critic_loss = 0.5 * ((discounted_rewards - self.critic.q_values) ** 2).mean()

        with symjax.Scope("actor_optimizer"):
            nn.optimizers.Adam(
                actor_loss,
                lr,
                params=self.actor.params(True),
            )
        with symjax.Scope("critic_optimizer"):
            nn.optimizers.Adam(
                critic_loss,
                lr,
                params=self.critic.params(True),
            )

        # create the actor update function
        self._train_actor = symjax.function(
            states,
            actions,
            advantages,
            outputs=actor_loss,
            updates=symjax.get_updates(scope="*/actor_optimizer"),
        )

        # create the critic update function
        self._train_critic = symjax.function(
            states,
            discounted_rewards,
            outputs=critic_loss,
            updates=symjax.get_updates(scope="*/critic_optimizer"),
        )

    def train(self, buffer, *args, **kwargs):
        indices = list(range(self.batch_size))
        states, actions, disc_rewards, advantages = buffer.sample(
            indices,
            ["state", "action", "reward-to-go", "advantage"],
        )

        # standardize the advantages
        advantages -= advantages.mean()
        advantages /= advantages.std()

        actor_loss = self._train_actor(states, actions, advantages)
        for i in range(self.train_v_iters):
            critic_loss = self._train_critic(states, disc_rewards)

        buffer.reset()
        return actor_loss, critic_loss
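The advantages themselves are produced by the buffer, whose code is not shown here; the docstring refers to GAE-Lambda estimation. A minimal numpy sketch of GAE-Lambda under its usual definition, delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) (illustrative only, not the buffer's actual code):

# Illustrative sketch only: GAE-Lambda advantages from rewards and value
# estimates, accumulated backwards with decay gamma * lam.
import numpy as np


def gae_advantages(rewards, values, last_value, gamma=0.99, lam=0.95):
    values = np.append(values, last_value)  # bootstrap value for the final state
    adv = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        running = delta + gamma * lam * running
        adv[t] = running
    return adv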
class PPO(agents.Agent):
    """
    Proximal Policy Optimization with a clipped surrogate objective.

    Instead of using target networks, one can record the old log-probs and
    obtain better advantage estimates.
    """

    def __init__(
        self,
        state_shape,
        actions_shape,
        batch_size,
        actor,
        critic,
        lr=1e-3,
        K_epochs=80,
        eps_clip=0.2,
        gamma=0.99,
        entropy_beta=0.01,
    ):
        self.actor = actor
        self.critic = critic
        self.gamma = gamma
        self.lr = lr
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        self.batch_size = batch_size

        states = T.Placeholder((batch_size,) + state_shape, "float32", name="states")
        actions = T.Placeholder(
            (batch_size,) + actions_shape, "float32", name="actions"
        )
        rewards = T.Placeholder((batch_size,), "float32", name="discounted_rewards")
        advantages = T.Placeholder((batch_size,), "float32", name="advantages")

        self.target_actor = actor(states, distribution="gaussian")
        self.actor = actor(states, distribution="gaussian")
        self.critic = critic(states)

        # Finding the ratio (pi_theta / pi_theta__old) and the
        # surrogate loss, https://arxiv.org/pdf/1707.06347.pdf
        with symjax.Scope("policy_loss"):
            ratios = T.exp(
                self.actor.actions.log_prob(actions)
                - self.target_actor.actions.log_prob(actions)
            )
            ratios = T.clip(ratios, 0, 10)
            clipped_ratios = T.clip(ratios, 1 - self.eps_clip, 1 + self.eps_clip)

            surr1 = advantages * ratios
            surr2 = advantages * clipped_ratios

            actor_loss = -(T.minimum(surr1, surr2)).mean()

        with symjax.Scope("monitor"):
            clipfrac = (
                ((ratios > (1 + self.eps_clip)) | (ratios < (1 - self.eps_clip)))
                .astype("float32")
                .mean()
            )
            approx_kl = (
                self.target_actor.actions.log_prob(actions)
                - self.actor.actions.log_prob(actions)
            ).mean()

        with symjax.Scope("critic_loss"):
            critic_loss = T.mean((rewards - self.critic.q_values) ** 2)

        with symjax.Scope("entropy"):
            entropy = self.actor.actions.entropy().mean()

        loss = actor_loss + critic_loss  # - entropy_beta * entropy

        with symjax.Scope("optimizer"):
            nn.optimizers.Adam(
                loss,
                lr,
                params=self.actor.params(True) + self.critic.params(True),
            )

        # create the update function
        self._train = symjax.function(
            states,
            actions,
            rewards,
            advantages,
            outputs=[actor_loss, critic_loss, clipfrac, approx_kl],
            updates=symjax.get_updates(scope="*optimizer"),
        )

        # initialize target as current
        self.update_target(1)

    def train(self, buffer, *args, **kwargs):
        indices = list(range(buffer.length))
        states, actions, rewards, advantages = buffer.sample(
            indices,
            ["state", "action", "reward-to-go", "advantage"],
        )

        # standardize the advantages, then optimize the policy for K epochs
        advantages -= advantages.mean()
        advantages /= advantages.std()

        for _ in range(self.K_epochs):
            for s, a, r, adv in symjax.data.utils.batchify(
                states,
                actions,
                rewards,
                advantages,
                batch_size=self.batch_size,
            ):
                loss = self._train(s, a, r, adv)
            # debug output: monitor the policy log-std variables
            print([v.value for v in symjax.get_variables(name="logsigma")])

        buffer.reset_data()
        self.update_target(1)
        return loss
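For reference, the clipped surrogate objective built in the ``policy_loss`` scope above, rewritten as a standalone numpy function (illustrative only; the extra clipping of the ratio to [0, 10] mirrors the symbolic code above):

# Illustrative sketch only: PPO clipped surrogate loss from log-probs under
# the new and old policies, taking the element-wise minimum of the clipped
# and unclipped objectives before averaging.
import numpy as np


def ppo_clip_loss(new_logp, old_logp, advantages, eps_clip=0.2):
    ratios = np.clip(np.exp(new_logp - old_logp), 0, 10)
    clipped = np.clip(ratios, 1 - eps_clip, 1 + eps_clip)
    return -np.minimum(ratios * advantages, clipped * advantages).mean()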