mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[RLlib] Correctly get bytes size of SampleBatch (#14801)
This commit is contained in:
parent
b84575c092
commit
93d4244d9c
2 changed files with 7 additions and 1 deletions
|
@ -74,6 +74,10 @@ class TestSAC(unittest.TestCase):
|
||||||
config["prioritized_replay"] = True
|
config["prioritized_replay"] = True
|
||||||
config["rollout_fragment_length"] = 10
|
config["rollout_fragment_length"] = 10
|
||||||
config["train_batch_size"] = 10
|
config["train_batch_size"] = 10
|
||||||
|
# If we use default buffer size (1e6), the buffer will take up
|
||||||
|
# 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
|
||||||
|
# available system memory (8.34816 GB).
|
||||||
|
config["buffer_size"] = 40000
|
||||||
num_iterations = 1
|
num_iterations = 1
|
||||||
|
|
||||||
ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
|
ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
|
||||||
|
|
|
@ -430,7 +430,9 @@ class SampleBatch(dict):
|
||||||
Returns:
|
Returns:
|
||||||
int: The overall size in bytes of the data buffer (all columns).
|
int: The overall size in bytes of the data buffer (all columns).
|
||||||
"""
|
"""
|
||||||
return sum(sys.getsizeof(d) for d in self.values())
|
return sum(
|
||||||
|
v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
|
||||||
|
for v in self.values())
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Reference in a new issue