ray/rllib/contrib/maddpg/maddpg_policy.py

import ray
from ray.rllib.agents.dqn.dqn_tf_policy import minimize_and_clip, _adjust_nstep
from ray.rllib.evaluation.metrics import LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.annotations import override
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.tf_policy import TFPolicy
from ray.rllib.utils import try_import_tf, try_import_tfp
import logging
from gym.spaces import Box, Discrete
import numpy as np

logger = logging.getLogger(__name__)

tf = try_import_tf()
tfp = try_import_tfp()


class MADDPGPostprocessing:
    """Implements agentwise termination signal and n-step learning."""

    @override(Policy)
    def postprocess_trajectory(self,
                               sample_batch,
                               other_agent_batches=None,
                               episode=None):
        # FIXME: Get done from info is required since agentwise done is not
        # supported now.
        sample_batch.data["dones"] = self.get_done_from_info(
            sample_batch.data["infos"])

        # N-step Q adjustments
        if self.config["n_step"] > 1:
            _adjust_nstep(self.config["n_step"], self.config["gamma"],
                          sample_batch[SampleBatch.CUR_OBS],
                          sample_batch[SampleBatch.ACTIONS],
                          sample_batch[SampleBatch.REWARDS],
                          sample_batch[SampleBatch.NEXT_OBS],
                          sample_batch[SampleBatch.DONES])
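            # (The DQN helper rewrites REWARDS, NEXT_OBS and DONES in place to
            # encode n-step returns, so no new batch columns are added here.)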

        return sample_batch


class MADDPGTFPolicy(MADDPGPostprocessing, TFPolicy):
    def __init__(self, obs_space, act_space, config):
        # _____ Initial Configuration
        config = dict(ray.rllib.contrib.maddpg.DEFAULT_CONFIG, **config)
        self.config = config
        self.global_step = tf.train.get_or_create_global_step()

        # FIXME: Get done from info is required since agentwise done is not
        # supported now.
        self.get_done_from_info = np.vectorize(
            lambda info: info.get("done", False))

        agent_id = config["agent_id"]
        if agent_id is None:
            raise ValueError("Must set `agent_id` in the policy config.")
        if type(agent_id) is not int:
            raise ValueError("Agent ids must be integers for MADDPG.")

        # _____ Environment Setting
        def _make_continuous_space(space):
            if isinstance(space, Box):
                return space
            elif isinstance(space, Discrete):
                return Box(
                    low=np.zeros((space.n, )), high=np.ones((space.n, )))
            else:
                raise UnsupportedSpaceException(
                    "Space {} is not supported.".format(space))

        obs_space_n = [
            _make_continuous_space(space) for _, (_, space, _, _) in
            sorted(config["multiagent"]["policies"].items())
        ]
        act_space_n = [
            _make_continuous_space(space) for _, (_, _, space, _) in
            sorted(config["multiagent"]["policies"].items())
        ]
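        # MADDPG trains a centralized critic, so the (continuous-ified)
        # observation and action spaces of *all* policies are collected here,
        # sorted by policy id to keep indexing consistent across agents.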

        # _____ Placeholders
        # Placeholders for policy evaluation and updates
        def _make_ph_n(space_n, name=""):
            return [
                tf.placeholder(
                    tf.float32,
                    shape=(None, ) + space.shape,
                    name=name + "_%d" % i) for i, space in enumerate(space_n)
            ]

        obs_ph_n = _make_ph_n(obs_space_n, "obs")
        act_ph_n = _make_ph_n(act_space_n, "actions")
        new_obs_ph_n = _make_ph_n(obs_space_n, "new_obs")
        new_act_ph_n = _make_ph_n(act_space_n, "new_actions")
        rew_ph = tf.placeholder(
            tf.float32, shape=None, name="rewards_{}".format(agent_id))
        done_ph = tf.placeholder(
            tf.float32, shape=None, name="dones_{}".format(agent_id))

        if config["use_local_critic"]:
            obs_space_n, act_space_n = [obs_space_n[agent_id]], [
                act_space_n[agent_id]
            ]
            obs_ph_n, act_ph_n = [obs_ph_n[agent_id]], [act_ph_n[agent_id]]
            new_obs_ph_n, new_act_ph_n = [new_obs_ph_n[agent_id]], [
                new_act_ph_n[agent_id]
            ]
            agent_id = 0
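            # (With `use_local_critic`, the lists above collapse to this
            # agent's own entries, i.e. an independent DDPG-style learner, and
            # the agent then indexes position 0 of every list.)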

        # _____ Value Network
        # Build critic network for t.
        critic, _, critic_model_n, critic_vars = self._build_critic_network(
            obs_ph_n,
            act_ph_n,
            obs_space_n,
            act_space_n,
            config["use_state_preprocessor"],
            config["critic_hiddens"],
            getattr(tf.nn, config["critic_hidden_activation"]),
            scope="critic")

        # Build critic network for t + 1.
        target_critic, _, _, target_critic_vars = self._build_critic_network(
            new_obs_ph_n,
            new_act_ph_n,
            obs_space_n,
            act_space_n,
            config["use_state_preprocessor"],
            config["critic_hiddens"],
            getattr(tf.nn, config["critic_hidden_activation"]),
            scope="target_critic")

        # Build critic loss.
        td_error = tf.subtract(
            tf.stop_gradient(
                rew_ph + (1.0 - done_ph) *
                (config["gamma"]**config["n_step"]) * target_critic[:, 0]),
            critic[:, 0])
        critic_loss = tf.reduce_mean(td_error**2)
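        # The (stop-gradient) target is the usual n-step return
        #     y = r + (1 - done) * gamma^n * Q_target(o', a'),
        # and the critic is trained on the mean squared TD error.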

        # _____ Policy Network
        # Build actor network for t.
        act_sampler, actor_feature, actor_model, actor_vars = (
            self._build_actor_network(
                obs_ph_n[agent_id],
                obs_space_n[agent_id],
                act_space_n[agent_id],
                config["use_state_preprocessor"],
                config["actor_hiddens"],
                getattr(tf.nn, config["actor_hidden_activation"]),
                scope="actor"))

        # Build actor network for t + 1.
        self.new_obs_ph = new_obs_ph_n[agent_id]
        self.target_act_sampler, _, _, target_actor_vars = (
            self._build_actor_network(
                self.new_obs_ph,
                obs_space_n[agent_id],
                act_space_n[agent_id],
                config["use_state_preprocessor"],
                config["actor_hiddens"],
                getattr(tf.nn, config["actor_hidden_activation"]),
                scope="target_actor"))

        # Build actor loss.
        act_n = act_ph_n.copy()
        act_n[agent_id] = act_sampler
        critic, _, _, _ = self._build_critic_network(
            obs_ph_n,
            act_n,
            obs_space_n,
            act_space_n,
            config["use_state_preprocessor"],
            config["critic_hiddens"],
            getattr(tf.nn, config["critic_hidden_activation"]),
            scope="critic")
        actor_loss = -tf.reduce_mean(critic)
        if config["actor_feature_reg"] is not None:
            actor_loss += config["actor_feature_reg"] * tf.reduce_mean(
                actor_feature**2)
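        # Deterministic-policy-gradient-style actor loss: this agent's action
        # placeholder is swapped for its sampled action and the centralized
        # critic (reusing the "critic" scope) is maximized; the optional extra
        # term regularizes the actor's pre-sampling logits.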

        # _____ Losses
        self.losses = {"critic": critic_loss, "actor": actor_loss}

        # _____ Optimizers
        self.optimizers = {
            "critic": tf.train.AdamOptimizer(config["critic_lr"]),
            "actor": tf.train.AdamOptimizer(config["actor_lr"])
        }

        # _____ Build variable update ops.
        self.tau = tf.placeholder_with_default(
            config["tau"], shape=(), name="tau")

        def _make_target_update_op(vs, target_vs, tau):
            return [
                target_v.assign(tau * v + (1.0 - tau) * target_v)
                for v, target_v in zip(vs, target_vs)
            ]

        self.update_target_vars = _make_target_update_op(
            critic_vars + actor_vars, target_critic_vars + target_actor_vars,
            self.tau)
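        # `update_target_vars` performs a soft (Polyak) update,
        #     target <- tau * online + (1 - tau) * target;
        # `update_target(1.0)` at the end of __init__ turns it into a hard
        # copy.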

        def _make_set_weight_op(variables):
            vs = list()
            for v in variables.values():
                vs += v
            phs = [
                tf.placeholder(
                    tf.float32,
                    shape=v.get_shape(),
                    name=v.name.split(":")[0] + "_ph") for v in vs
            ]
            return tf.group(*[v.assign(ph) for v, ph in zip(vs, phs)]), phs

        self.vars = {
            "critic": critic_vars,
            "actor": actor_vars,
            "target_critic": target_critic_vars,
            "target_actor": target_actor_vars
        }
        self.update_vars, self.vars_ph = _make_set_weight_op(self.vars)

        # _____ TensorFlow Initialization
        self.sess = tf.get_default_session()

        def _make_loss_inputs(placeholders):
            return [(ph.name.split("/")[-1].split(":")[0], ph)
                    for ph in placeholders]

        loss_inputs = _make_loss_inputs(obs_ph_n + act_ph_n + new_obs_ph_n +
                                        new_act_ph_n + [rew_ph, done_ph])
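        # `loss_inputs` pairs each placeholder with its bare tensor name (e.g.
        # "obs_0", "actions_0", "new_obs_0", ...); TFPolicy feeds the matching
        # sample-batch columns, which the accompanying MADDPG trainer is
        # expected to assemble under the same names.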

        TFPolicy.__init__(
            self,
            obs_space,
            act_space,
            config=config,
            sess=self.sess,
            obs_input=obs_ph_n[agent_id],
            sampled_action=act_sampler,
            loss=actor_loss + critic_loss,
            loss_inputs=loss_inputs,
            dist_inputs=actor_feature)

        self.sess.run(tf.global_variables_initializer())

        # Hard initial update
        self.update_target(1.0)

    @override(TFPolicy)
    def optimizer(self):
        return None

    @override(TFPolicy)
    def gradients(self, optimizer, loss):
        if self.config["grad_norm_clipping"] is not None:
            self.gvs = {
                k: minimize_and_clip(optimizer, self.losses[k], self.vars[k],
                                     self.config["grad_norm_clipping"])
                for k, optimizer in self.optimizers.items()
            }
        else:
            self.gvs = {
                k: optimizer.compute_gradients(self.losses[k], self.vars[k])
                for k, optimizer in self.optimizers.items()
            }
        return self.gvs["critic"] + self.gvs["actor"]
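
    # The critic and actor grads-and-vars are concatenated into one list so
    # TFPolicy's generic update path sees a single gradient op; `build_apply_op`
    # below re-applies them per optimizer, stepping the critic first and then
    # the actor, and increments the global step once per update.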

    @override(TFPolicy)
    def build_apply_op(self, optimizer, grads_and_vars):
        critic_apply_op = self.optimizers["critic"].apply_gradients(
            self.gvs["critic"])

        with tf.control_dependencies([tf.assign_add(self.global_step, 1)]):
            with tf.control_dependencies([critic_apply_op]):
                actor_apply_op = self.optimizers["actor"].apply_gradients(
                    self.gvs["actor"])

        return actor_apply_op

    @override(TFPolicy)
    def extra_compute_action_feed_dict(self):
        return {}

    @override(TFPolicy)
    def extra_compute_grad_fetches(self):
        return {LEARNER_STATS_KEY: {}}

    @override(TFPolicy)
    def get_weights(self):
        var_list = []
        for var in self.vars.values():
            var_list += var
        return self.sess.run(var_list)

    @override(TFPolicy)
    def set_weights(self, weights):
        self.sess.run(
            self.update_vars, feed_dict=dict(zip(self.vars_ph, weights)))

    @override(Policy)
    def get_state(self):
        return TFPolicy.get_state(self)

    @override(Policy)
    def set_state(self, state):
        TFPolicy.set_state(self, state)

    def _build_critic_network(self,
                              obs_n,
                              act_n,
                              obs_space_n,
                              act_space_n,
                              use_state_preprocessor,
                              hiddens,
                              activation=None,
                              scope=None):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE) as scope:
            if use_state_preprocessor:
                model_n = [
                    ModelCatalog.get_model({
                        "obs": obs,
                        "is_training": self._get_is_training_placeholder(),
                    }, obs_space, act_space, 1, self.config["model"])
                    for obs, obs_space, act_space in zip(
                        obs_n, obs_space_n, act_space_n)
                ]
                out_n = [model.last_layer for model in model_n]
                out = tf.concat(out_n + act_n, axis=1)
            else:
                model_n = [None] * len(obs_n)
                out = tf.concat(obs_n + act_n, axis=1)

            for hidden in hiddens:
                out = tf.layers.dense(
                    out, units=hidden, activation=activation)
            feature = out
            out = tf.layers.dense(feature, units=1, activation=None)

        return out, feature, model_n, tf.global_variables(scope.name)

    def _build_actor_network(self,
                             obs,
                             obs_space,
                             act_space,
                             use_state_preprocessor,
                             hiddens,
                             activation=None,
                             scope=None):
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE) as scope:
            if use_state_preprocessor:
                model = ModelCatalog.get_model({
                    "obs": obs,
                    "is_training": self._get_is_training_placeholder(),
                }, obs_space, act_space, 1, self.config["model"])
                out = model.last_layer
            else:
                model = None
                out = obs

            for hidden in hiddens:
                out = tf.layers.dense(
                    out, units=hidden, activation=activation)
            feature = tf.layers.dense(
                out, units=act_space.shape[0], activation=None)
            sampler = tfp.distributions.RelaxedOneHotCategorical(
                temperature=1.0, logits=feature).sample()
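            # RelaxedOneHotCategorical is the Gumbel-Softmax distribution: its
            # sample is a differentiable relaxation of a one-hot action, so
            # the critic's gradient can flow back into the actor logits.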

        return sampler, feature, model, tf.global_variables(scope.name)

    def update_target(self, tau=None):
        if tau is not None:
            self.sess.run(self.update_target_vars, {self.tau: tau})
        else:
            self.sess.run(self.update_target_vars)
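

# Usage sketch (hypothetical, for illustration only): this policy is normally
# instantiated by RLlib's "contrib/MADDPG" trainer rather than constructed by
# hand. Each entry in config["multiagent"]["policies"] supplies a per-policy
# config that must set an integer "agent_id" matching the policy's sort order,
# e.g. (obs_space/act_space stand in for the env's actual spaces):
#
#     "multiagent": {
#         "policies": {
#             "agent_0": (MADDPGTFPolicy, obs_space, act_space,
#                         {"agent_id": 0}),
#             "agent_1": (MADDPGTFPolicy, obs_space, act_space,
#                         {"agent_id": 1}),
#         },
#         # Assumes env agent ids equal the policy ids above.
#         "policy_mapping_fn": lambda agent_id: agent_id,
#     }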