mirror of
https://github.com/vale981/ray
synced 2025-03-05 18:11:42 -05:00
[RLlib] Remove requirement for dataclasses in rllib (not supported in py3.5) (#9237)
This commit is contained in:
parent
c11855728a
commit
b4c0b942fe
5 changed files with 31 additions and 30 deletions
|
@ -4,7 +4,7 @@ import ray
|
|||
import ray.experimental.tf_utils
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
tf, _, _ = try_import_tf()
|
||||
|
||||
|
||||
def make_linear_network(w_name=None, b_name=None):
|
||||
|
|
|
@ -9,13 +9,12 @@ from ray.rllib.agents.ddpg.ddpg_torch_policy import ddpg_actor_critic_loss as \
|
|||
from ray.rllib.agents.sac.tests.test_sac import SimpleEnv
|
||||
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.numpy import fc, huber_loss, l2_loss, relu, sigmoid
|
||||
from ray.rllib.utils.test_utils import check, check_compute_single_action, \
|
||||
framework_iterator
|
||||
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
|
||||
|
||||
tf = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
|
||||
|
||||
|
|
|
@ -2,12 +2,9 @@ import unittest
|
|||
|
||||
import ray
|
||||
import ray.rllib.agents.maml as maml
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.test_utils import check_compute_single_action, \
|
||||
framework_iterator
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
class TestMAML(unittest.TestCase):
|
||||
@classmethod
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from ray.rllib.utils.types import TensorType
|
||||
|
||||
|
||||
@dataclass
|
||||
class ViewRequirement:
|
||||
"""Single view requirement (for one column in a ModelV2 input_dict).
|
||||
|
||||
|
@ -24,29 +22,39 @@ class ViewRequirement:
|
|||
>>> print(req)
|
||||
{"obs": ViewRequirement(timesteps=0)}
|
||||
"""
|
||||
# The data column name from the SampleBatch (str key).
|
||||
# If None, use the dict key under which this ViewRequirement resides.
|
||||
data_col: str = None
|
||||
|
||||
# List of relative (or absolute timesteps) to be present in the
|
||||
# input_dict.
|
||||
timesteps: int = 0
|
||||
def __init__(self,
|
||||
data_col: Optional[str] = None,
|
||||
timesteps: int = 0,
|
||||
fill_mode: str = "zeros",
|
||||
repeat_mode: str = "all"):
|
||||
"""Initializes a ViewRequirement object.
|
||||
|
||||
# Switch on absolute timestep mode. Default: False.
|
||||
# TODO: (sven)
|
||||
# "absolute_timesteps",
|
||||
Args:
|
||||
data_col (): The data column name from the SampleBatch (str key).
|
||||
If None, use the dict key under which this ViewRequirement
|
||||
resides.
|
||||
timesteps (Union[List[int], int]): List of relative (or absolute
|
||||
timesteps) to be present in the input_dict.
|
||||
fill_mode (str): The fill mode in case t<0 or t>H.
|
||||
One of "zeros", "tile".
|
||||
repeat_mode (str): The repeat-mode (one of "all" or "only_first").
|
||||
E.g. for training, we only want the first internal state
|
||||
timestep (the NN will calculate all others again anyways).
|
||||
"""
|
||||
self.data_col = data_col
|
||||
self.timesteps = timesteps
|
||||
|
||||
# The fill mode in case t<0 or t>H: One of "zeros", "tile".
|
||||
fill_mode: str = "zeros"
|
||||
# Switch on absolute timestep mode. Default: False.
|
||||
# TODO: (sven)
|
||||
# "absolute_timesteps",
|
||||
|
||||
# The repeat-mode (one of "all" or "only_first"). E.g. for training,
|
||||
# we only want the first internal state timestep (the NN will
|
||||
# calculate all others again anyways).
|
||||
repeat_mode: str = "all"
|
||||
self.fill_mode = fill_mode
|
||||
self.repeat_mode = repeat_mode
|
||||
|
||||
# Provide all data as time major (default: False).
|
||||
# TODO: (sven)
|
||||
# "time_major",
|
||||
# Provide all data as time major (default: False).
|
||||
# TODO: (sven)
|
||||
# "time_major",
|
||||
|
||||
|
||||
def get_trajectory_view(
|
||||
|
|
|
@ -25,11 +25,8 @@ import yaml
|
|||
|
||||
import ray
|
||||
from ray import tune
|
||||
from ray.rllib.utils.framework import try_import_tf
|
||||
from ray.rllib.utils.test_utils import check_learning_achieved
|
||||
|
||||
tf = try_import_tf()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--run", type=str, default=None)
|
||||
parser.add_argument("--torch", action="store_true")
|
||||
|
|
Loading…
Add table
Reference in a new issue