ray/rllib/contrib/random_agent/random_agent.py

import numpy as np

from ray.rllib.agents.trainer import Trainer, with_common_config
from ray.rllib.utils.annotations import override


# yapf: disable
# __sphinx_doc_begin__
class RandomAgent(Trainer):
    """Policy that takes random actions and never learns."""

    _name = "RandomAgent"
    _default_config = with_common_config({
        # Number of complete episodes to sample per call to train().
        "rollouts_per_iteration": 10,
    })

    @override(Trainer)
    def _init(self, config, env_creator):
        self.env = env_creator(config["env_config"])

    @override(Trainer)
    def _train(self):
        rewards = []
        steps = 0
        for _ in range(self.config["rollouts_per_iteration"]):
            obs = self.env.reset()
            done = False
            reward = 0.0
            # Roll out one episode, sampling uniformly at random
            # from the env's action space.
            while not done:
                action = self.env.action_space.sample()
                obs, r, done, info = self.env.step(action)
                reward += r
                steps += 1
            rewards.append(reward)
        return {
            "episode_reward_mean": np.mean(rewards),
            "timesteps_this_iter": steps,
        }
# __sphinx_doc_end__
# don't enable yapf after, it's buggy here
if __name__ == "__main__":
    trainer = RandomAgent(
        env="CartPole-v0", config={"rollouts_per_iteration": 10})
    result = trainer.train()
    assert result["episode_reward_mean"] > 10, result
    print("Test: OK")