mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[RLlib] Add HowTo set env seed to our custom env example script. (#14471)
This commit is contained in:
parent
897b84b300
commit
78a134efa2
1 changed files with 11 additions and 2 deletions
|
@ -12,10 +12,12 @@ import gym
|
|||
from gym.spaces import Discrete, Box
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.tune import grid_search
|
||||
from ray.rllib.env.env_context import EnvContext
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
|
||||
|
@ -41,12 +43,14 @@ class SimpleCorridor(gym.Env):
|
|||
|
||||
You can configure the length of the corridor via the env config."""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: EnvContext):
|
||||
self.end_pos = config["corridor_length"]
|
||||
self.cur_pos = 0
|
||||
self.action_space = Discrete(2)
|
||||
self.observation_space = Box(
|
||||
0.0, self.end_pos, shape=(1, ), dtype=np.float32)
|
||||
# Set the seed. This is only used for the final (reach goal) reward.
|
||||
self.seed(config.worker_index * config.num_workers)
|
||||
|
||||
def reset(self):
|
||||
self.cur_pos = 0
|
||||
|
@ -59,7 +63,12 @@ class SimpleCorridor(gym.Env):
|
|||
elif action == 1:
|
||||
self.cur_pos += 1
|
||||
done = self.cur_pos >= self.end_pos
|
||||
return [self.cur_pos], 1.0 if done else -0.1, done, {}
|
||||
# Produce a random reward when we reach the goal.
|
||||
return [self.cur_pos], \
|
||||
random.random() * 2 if done else -0.1, done, {}
|
||||
|
||||
def seed(self, seed=None):
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
class CustomModel(TFModelV2):
|
||||
|
|
Loading…
Add table
Reference in a new issue