2018-12-21 03:44:34 +09:00
|
|
|
import numpy as np
|
|
|
|
|
2022-06-11 15:10:39 +02:00
|
|
|
from ray.rllib.algorithms.algorithm import Algorithm, with_common_config
|
2018-12-21 03:44:34 +09:00
|
|
|
from ray.rllib.utils.annotations import override
|
2022-06-11 15:10:39 +02:00
|
|
|
from ray.rllib.utils.typing import AlgorithmConfigDict
|
2018-12-21 03:44:34 +09:00
|
|
|
|
|
|
|
|
2022-02-08 16:29:25 -08:00
|
|
|
# fmt: off
|
2018-12-21 03:44:34 +09:00
|
|
|
# __sphinx_doc_begin__
|
2022-06-11 15:10:39 +02:00
|
|
|
class RandomAgent(Algorithm):
    """Algorithm that acts by sampling random actions and never learns."""

    @classmethod
    @override(Algorithm)
    def get_default_config(cls) -> AlgorithmConfigDict:
        """Return the common config extended with this algorithm's settings."""
        overrides = {
            # Number of full episodes played on every call to `step()`.
            "rollouts_per_iteration": 10,
            "framework": "tf",  # not used
        }
        return with_common_config(overrides)

    @override(Algorithm)
    def _init(self, config, env_creator):
        """Build the environment that `step()` rolls out in.

        Args:
            config: Algorithm config; its "env_config" entry is handed to
                `env_creator`.
            env_creator: Callable that builds the environment from an
                env config.
        """
        env_config = config["env_config"]
        self.env = env_creator(env_config)

    @override(Algorithm)
    def step(self):
        """Play `rollouts_per_iteration` episodes with sampled actions.

        Returns:
            A dict holding the mean per-episode return under
            "episode_reward_mean" and the total number of environment
            steps taken under "timesteps_this_iter".
        """
        env = self.env
        episode_returns = []
        total_steps = 0
        for _ in range(self.config["rollouts_per_iteration"]):
            # The observations are never consulted: actions are sampled
            # from the action space regardless of the env state.
            env.reset()
            episode_return = 0.0
            while True:
                sampled_action = env.action_space.sample()
                _, step_reward, finished, _ = env.step(sampled_action)
                episode_return += step_reward
                total_steps += 1
                if finished:
                    break
            episode_returns.append(episode_return)
        return {
            "episode_reward_mean": np.mean(episode_returns),
            "timesteps_this_iter": total_steps,
        }
|
|
|
|
# __sphinx_doc_end__
|
2022-02-09 22:12:11 -08:00
|
|
|
# FIXME: We switched our code formatter from YAPF to Black. Check if we can enable code
|
|
|
|
# formatting on this module and update the comment below. See issue #21318.
|
2018-12-21 03:44:34 +09:00
|
|
|
# don't enable yapf after, it's buggy here
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: run a single training iteration on CartPole and check
    # that the reported mean episode reward exceeds 10.
    agent_config = {"rollouts_per_iteration": 10}
    algo = RandomAgent(env="CartPole-v0", config=agent_config)
    result = algo.train()
    assert result["episode_reward_mean"] > 10, result
    print("Test: OK")
|