import symjax
import symjax.tensor as T
import numpy as np
class Agent(object):
    """Base RL agent delegating action selection to ``self.actor``.

    Subclasses / policy methods are expected to attach at least an
    ``actor`` attribute, optionally a ``critic``, and optionally
    ``target_actor`` / ``target_critic`` networks for soft updates.
    """

    def get_action(self, state):
        """Return the actor's action and extra-info dict for one state."""
        return self.actor.get_action(state)

    def get_actions(self, states):
        """Return the actor's actions and extra-info dict for a batch of states."""
        return self.actor.get_actions(states)

    def play(self, state, env, skip_frames=1, reward_scaling=1):
        """Act once in ``env`` (repeating the action ``skip_frames`` times).

        Parameters
        ----------
        state:
            current environment observation
        env:
            environment exposing a Gym-style ``step(action)`` returning
            ``(next_state, reward, terminal, info)``
        skip_frames: int
            number of times the chosen action is repeated; rewards are averaged
        reward_scaling: float
            multiplier applied to the (averaged) reward

        Returns
        -------
        (next_state, terminal, transition) where ``transition`` is a dict with
        keys ``state, action, reward, next-state, terminal`` plus any extra
        entries produced by the actor/critic (e.g. ``log_probs``, ``V``).
        """
        action, extra = self.get_action(state)

        # If a critic is attached, also record its value estimate.
        if hasattr(self, "critic"):
            if hasattr(self.critic, "actions"):
                # state-action critic: Q(s, a)
                value = self.critic.get_q_value(state, action)
            else:
                # state-only critic: V(s)
                value = self.critic.get_q_value(state)
            extra.update({"V": value})

        reward = 0
        for k in range(skip_frames):
            next_state, r, terminal, info = env.step(action)
            reward += r
            if terminal:
                break
        # NOTE(review): the reward is averaged over skip_frames even when the
        # episode terminated early (fewer frames executed) — confirm intended.
        reward /= skip_frames
        reward *= reward_scaling

        base = {
            "state": state,
            "action": action,
            "reward": reward,
            "next-state": next_state,
            "terminal": terminal,
        }
        return next_state, terminal, {**base, **extra}

    def update_target(self, tau=None):
        """Soft-update target networks: target <- target*(1-tau) + current*tau.

        The compiled update function is built lazily on first call, gathering
        the trainable parameters of whichever target networks are present.

        Parameters
        ----------
        tau: float or None
            interpolation factor; if None, falls back to ``self.tau`` and
            raises RuntimeError when neither is available.
        """
        if not hasattr(self, "_update_target"):
            # Build (once) a compiled function performing the Polyak update.
            with symjax.Scope("update_target"):
                targets = []
                currents = []
                if hasattr(self, "target_actor"):
                    targets += self.target_actor.params(True)
                    currents += self.actor.params(True)
                if hasattr(self, "target_critic"):
                    targets += self.target_critic.params(True)
                    currents += self.critic.params(True)

                _tau = T.Placeholder((), "float32")
                updates = {
                    t: t * (1 - _tau) + a * _tau
                    for t, a in zip(targets, currents)
                }
                self._update_target = symjax.function(_tau, updates=updates)

        # BUG FIX: the original used ``tau = tau or self.tau`` which silently
        # replaced an explicit ``tau=0`` with ``self.tau`` (or raised a bare
        # AttributeError when ``self.tau`` was missing).  Only fall back when
        # tau is actually None.
        if tau is None:
            if not hasattr(self, "tau"):
                raise RuntimeError("tau must be specified")
            tau = self.tau
        self._update_target(tau)
class Actor(object):
    """actor (state to action mapping) for RL

    This class implements an actor. The user must first define its own class
    inheriting from :py:class:`Actor` and implementing only the
    `create_network` method. This method will then be used internally to
    instanciate the actor network.

    If the used distribution is `symjax.probabilities.Normal` then the output
    of the `create_network` method should be first the mean and then the
    `covariance`.

    Notes:
    ------

    In general the user should not instanciate this class, instead pass the
    user's inherited class (uninstanciated) to a policy-learning method.

    Parameters:
    -----------

    states: Tensor-like
        the states of the environment (batch size in first axis)

    actions_distribution: None or symjax.probabilities.Distribution object
        the distribution for the actions, if the policy is deterministic, then
        put this to `None`. Note, this is different than the noise parameter
        employed for exploration, this is simply the rv modeling of the
        actions used to compute probabilities of sampled actions and
        the likes
    """

    def __init__(self, states, actions_distribution=None, name="actor"):
        self.state_shape = states.shape[1:]
        # batch-of-one placeholder used to compile the single-state function
        state = T.Placeholder((1,) + states.shape[1:], "float32")
        self.actions_distribution = actions_distribution

        with symjax.Scope(name):
            if actions_distribution == symjax.probabilities.Normal:
                # the user network must return (means, covariances)
                means, covs = self.create_network(states)
                actions = actions_distribution(means, cov=covs)
                samples = actions.sample()
                samples_log_prob = actions.log_prob(samples)

                # same distribution, re-expressed on the single-state input
                action = symjax.probabilities.MultivariateNormal(
                    means.clone({states: state}),
                    cov=covs.clone({states: state}),
                )
                # BUG FIX: the original referenced ``self.action`` here, but
                # ``self.action`` is only assigned at the end of __init__ so
                # this raised AttributeError; use the local ``action``.
                sample = action.sample()
                sample_log_prob = action.log_prob(sample)

                self._get_actions = symjax.function(
                    states, outputs=[samples, samples_log_prob]
                )
                self._get_action = symjax.function(
                    state,
                    outputs=[sample[0], sample_log_prob[0]],
                )
            elif actions_distribution is None:
                # deterministic policy: the network output is the action
                actions = self.create_network(states)
                action = actions.clone({states: state})
                self._get_actions = symjax.function(states, outputs=actions)
                self._get_action = symjax.function(state, outputs=action[0])

            self._params = symjax.get_variables(
                trainable=None, scope=symjax.current_graph().scope_name
            )

        self.actions = actions
        self.state = state
        self.action = action

    def params(self, trainable):
        """Return the actor variables, optionally filtered by trainability."""
        if trainable is None:
            return self._params
        return [p for p in self._params if p.trainable == trainable]

    def get_action(self, state):
        """Compute the action for a single state (batch axis added if absent).

        Returns ``(action, extra)`` where ``extra`` holds ``log_probs`` for
        stochastic policies and is empty for deterministic ones.
        """
        if state.ndim == len(self.state_shape):
            state = state[np.newaxis, :]
        if not hasattr(self, "_get_action"):
            raise RuntimeError("actor not well initialized")
        if self.actions_distribution is None:
            return self._get_action(state), {}
        else:
            a, probs = self._get_action(state)
            return a, {"log_probs": probs}

    def get_actions(self, state):
        """Batched version of :meth:`get_action`."""
        if not hasattr(self, "_get_actions"):
            raise RuntimeError("actor not well initialized")
        if self.actions_distribution is None:
            return self._get_actions(state), {}
        else:
            a, probs = self._get_actions(state)
            return a, {"log_probs": probs}

    # BUG FIX: ``action_shape`` now defaults to None — __init__ calls
    # ``self.create_network(states)`` with a single argument, so a required
    # second parameter made the stub raise TypeError instead of the intended
    # RuntimeError.  User overrides remain backward compatible.
    def create_network(self, states, action_shape=None):
        """creating of the actor network returning the actions

        This method has to be implemented by the user in a own actor class
        inheriting from `symjax.rl.Actor`. This method should take
        two arguments, the states and the action_dim, and return
        the actions after a possible nonlinear transformation of the given
        states by say a deep networks

        Parameters:
        -----------

        states: Tensor
            the states with shape (batch_size, *state_shape)

        action_shape: tuple or list
            the shape of a (single) action, for example in classical
            pendulum this would be `(2,)`.

        Returns:
        --------

        actions: Tensor
            the actions with shape (batch_size, *action_shape)
        """
        raise RuntimeError("Not implemented, user should define its own")
class Critic(object):
    """Critic (value estimator) for RL.

    The user implements ``create_network``; depending on whether ``actions``
    is provided at construction, the critic models Q(s, a) or V(s).

    Parameters:
    -----------

    states: Tensor-like
        batch of states (batch size in first axis)

    actions: Tensor-like or None
        batch of actions for a state-action critic, or None for a
        state-only critic
    """

    def __init__(self, states, actions=None):
        self.state_shape = states.shape[1:]
        # batch-of-one placeholder used to compile the single-sample function
        state = T.Placeholder((1,) + states.shape[1:], "float32", name="critic_state")

        # BUG FIX: explicit ``is not None`` test — the original ``if actions:``
        # relied on tensor truthiness, which is fragile for graph tensors.
        if actions is not None:
            self.action_shape = actions.shape[1:]
            action = T.Placeholder(
                (1,) + actions.shape[1:], "float32", name="critic_action"
            )
            action_shape = action.shape[1:]
            with symjax.Scope("critic"):
                q_values = self.create_network(states, actions)
                # accept (batch, 1) outputs and squeeze to (batch,)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                q_value = q_values.clone({states: state, actions: action})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name
                )
            inputs = [states, actions]
            input = [state, action]
            self.actions = actions
            self.action = action
        else:
            with symjax.Scope("critic"):
                q_values = self.create_network(states)
                # accept (batch, 1) outputs and squeeze to (batch,)
                if q_values.ndim == 2:
                    assert q_values.shape[1] == 1
                    q_values = q_values[:, 0]
                q_value = q_values.clone({states: state})
                self._params = symjax.get_variables(
                    trainable=None, scope=symjax.current_graph().scope_name
                )
            inputs = [states]
            input = [state]

        self.q_values = q_values
        self.state = state
        self.states = states
        self._get_q_values = symjax.function(*inputs, outputs=q_values)
        self._get_q_value = symjax.function(*input, outputs=q_value[0])

    def params(self, trainable):
        """Return the critic variables, optionally filtered by trainability."""
        if trainable is None:
            return self._params
        return [p for p in self._params if p.trainable == trainable]

    def get_q_value(self, state, action=None):
        """Value of a single state (and action, for Q critics).

        A batch axis is prepended when the inputs are unbatched.
        """
        if state.ndim == len(self.state_shape):
            state = state[np.newaxis, :]
        if action is not None:
            if action.ndim == len(self.action_shape):
                action = action[np.newaxis, :]
        if not hasattr(self, "_get_q_value"):
            raise RuntimeError("critic not well initialized")
        if action is None:
            return self._get_q_value(state)
        else:
            return self._get_q_value(state, action)

    def get_q_values(self, states, actions=None):
        """Batched version of :meth:`get_q_value`."""
        if not hasattr(self, "_get_q_values"):
            raise RuntimeError("critic not well initialized")
        if actions is not None:
            return self._get_q_values(states, actions)
        else:
            return self._get_q_values(states)

    def create_network(self, states, actions=None):
        """User-implemented network mapping states (and actions) to values."""
        raise RuntimeError("Not implemented, user should define its own")
class OrnsteinUhlenbeckProcess:
    """Ornstein-Uhlenbeck exploration noise: dXt = theta*(mu-Xt)*dt + sigma*dWt

    Temporally-correlated noise added to actions for exploration; the noise
    magnitude decays geometrically with the episode index.

    Parameters
    ----------
    mean: float
        long-run mean ``mu`` the process reverts to
    std_dev: float
        diffusion coefficient ``sigma``
    theta: float
        mean-reversion rate
    dt: float
        integration time step
    noise_decay: float
        per-episode geometric decay of the noise scale
    initial_noise_scale: float
        noise scale at episode 0
    init: array-like or None
        initial process state; defaults to a single zero
    """

    def __init__(
        self,
        mean=0.0,
        std_dev=0.2,
        theta=0.15,
        dt=1e-2,
        noise_decay=0.99,
        initial_noise_scale=1,
        init=None,
    ):
        self.theta = theta
        self.mean = mean
        self.std_dev = std_dev
        # BUG FIX: the original stored ``self.dt = (dt,)`` (stray trailing
        # comma producing a 1-tuple); it only worked through NumPy
        # broadcasting.  dt is a scalar — store it as one.
        self.dt = dt
        self.init = init
        self.noise_decay = noise_decay
        self.initial_noise_scale = initial_noise_scale
        self.end_episode()

    def __call__(self, action, episode):
        """Return ``action`` perturbed by the current (decayed) OU noise."""
        # NOTE(review): this creates a fresh symjax Variable on *every* call
        # and never reads it — looks like leftover/incomplete code; confirm.
        with symjax.Scope("OUProcess"):
            self.episode = T.Variable(1, "float32", name="episode", trainable=False)
        self.noise_scale = self.initial_noise_scale * self.noise_decay ** episode

        # Euler-Maruyama step of the OU SDE.
        x = (
            self.process
            + self.theta * (self.mean - self.process) * self.dt
            + self.std_dev * np.sqrt(self.dt) * np.random.normal(size=action.shape)
        )
        # Store x into process
        # Makes next noise dependent on current one
        self.process = x
        return action + self.noise_scale * self.process

    def end_episode(self):
        """Reset the process state to its initial value."""
        if self.init is None:
            self.process = np.zeros(1)
        else:
            self.process = self.init