[RLlib] Remove requirement for dataclasses in rllib (not supported in py3.5) (#9237)

This commit is contained in:
Sven Mika 2020-07-01 17:31:44 +02:00 committed by GitHub
parent c11855728a
commit b4c0b942fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 31 additions and 30 deletions

View file

@ -4,7 +4,7 @@ import ray
import ray.experimental.tf_utils
from ray.rllib.utils.framework import try_import_tf
tf = try_import_tf()
tf, _, _ = try_import_tf()
def make_linear_network(w_name=None, b_name=None):

View file

@ -9,13 +9,12 @@ from ray.rllib.agents.ddpg.ddpg_torch_policy import ddpg_actor_critic_loss as \
from ray.rllib.agents.sac.tests.test_sac import SimpleEnv
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
framework_iterator
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
tf = try_import_tf()
torch, _ = try_import_torch()

View file

@ -2,12 +2,9 @@ import unittest
import ray
import ray.rllib.agents.maml as maml
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_compute_single_action, \
framework_iterator
tf = try_import_tf()
class TestMAML(unittest.TestCase):
@classmethod

View file

@ -1,11 +1,9 @@
from dataclasses import dataclass
import numpy as np
from typing import Dict
from typing import Dict, Optional
from ray.rllib.utils.types import TensorType
@dataclass
class ViewRequirement:
"""Single view requirement (for one column in a ModelV2 input_dict).
@ -24,29 +22,39 @@ class ViewRequirement:
>>> print(req)
{"obs": ViewRequirement(timesteps=0)}
"""
# The data column name from the SampleBatch (str key).
# If None, use the dict key under which this ViewRequirement resides.
data_col: str = None
# List of relative (or absolute timesteps) to be present in the
# input_dict.
timesteps: int = 0
def __init__(self,
data_col: Optional[str] = None,
timesteps: int = 0,
fill_mode: str = "zeros",
repeat_mode: str = "all"):
"""Initializes a ViewRequirement object.
# Switch on absolute timestep mode. Default: False.
# TODO: (sven)
# "absolute_timesteps",
Args:
data_col (Optional[str]): The data column name from the SampleBatch
(str key). If None, use the dict key under which this
ViewRequirement resides.
timesteps (Union[List[int], int]): List of relative (or absolute)
timesteps to be present in the input_dict.
fill_mode (str): The fill mode in case t<0 or t>H.
One of "zeros", "tile".
repeat_mode (str): The repeat-mode (one of "all" or "only_first").
E.g. for training, we only want the first internal state
timestep (the NN will calculate all others again anyway).
"""
self.data_col = data_col
self.timesteps = timesteps
# The fill mode in case t<0 or t>H: One of "zeros", "tile".
fill_mode: str = "zeros"
# Switch on absolute timestep mode. Default: False.
# TODO: (sven)
# "absolute_timesteps",
# The repeat-mode (one of "all" or "only_first"). E.g. for training,
# we only want the first internal state timestep (the NN will
# calculate all others again anyways).
repeat_mode: str = "all"
self.fill_mode = fill_mode
self.repeat_mode = repeat_mode
# Provide all data as time major (default: False).
# TODO: (sven)
# "time_major",
# Provide all data as time major (default: False).
# TODO: (sven)
# "time_major",
def get_trajectory_view(

View file

@ -25,11 +25,8 @@ import yaml
import ray
from ray import tune
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved
tf = try_import_tf()
parser = argparse.ArgumentParser()
parser.add_argument("--run", type=str, default=None)
parser.add_argument("--torch", action="store_true")