[RLlib] Correctly get bytes size of SampleBatch (#14801)

Raphael CHEN 2021-03-31 01:24:58 +08:00 committed by GitHub
parent b84575c092
commit 93d4244d9c
2 changed files with 7 additions and 1 deletion


@@ -74,6 +74,10 @@ class TestSAC(unittest.TestCase):
         config["prioritized_replay"] = True
         config["rollout_fragment_length"] = 10
         config["train_batch_size"] = 10
+        # If we use default buffer size (1e6), the buffer will take up
+        # 169.445 GB memory, which is beyond travis-ci's current (Mar 19, 2021)
+        # available system memory (8.34816 GB).
+        config["buffer_size"] = 40000
         num_iterations = 1
         ModelCatalog.register_custom_model("batch_norm", KerasBatchNormModel)
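The buffer-size comment above can be sanity-checked with a bit of arithmetic: 169.445 GB spread over the default 1e6 stored samples implies roughly 169 KB per sample, so shrinking the buffer to 40000 entries brings the footprint well under the 8.34816 GB CI limit. A minimal sketch of that estimate (the per-sample size is back-solved from the comment, not measured from RLlib):

GB = 1e9  # the comment above quotes decimal gigabytes

def buffer_memory_gb(buffer_size, bytes_per_sample):
    """Rough replay-buffer footprint: number of samples times bytes per sample."""
    return buffer_size * bytes_per_sample / GB

# Per-sample size implied by the comment: 169.445 GB at buffer_size = 1e6.
bytes_per_sample = 169.445 * GB / 1e6

print(buffer_memory_gb(1_000_000, bytes_per_sample))  # ~169.4 GB with the default buffer_size
print(buffer_memory_gb(40_000, bytes_per_sample))     # ~6.8 GB, fits under the 8.34816 GB CI limit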


@@ -430,7 +430,9 @@ class SampleBatch(dict):
         Returns:
             int: The overall size in bytes of the data buffer (all columns).
         """
-        return sum(sys.getsizeof(d) for d in self.values())
+        return sum(
+            v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)
+            for v in self.values())

     def get(self, key, default=None):
         try:
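The motivation for the size_bytes() change: sys.getsizeof() counts only the memory owned by the array object itself, so for numpy arrays that do not own their data (for example, views produced by slicing a batch) it reports just the small array header and badly under-counts the column, whereas ndarray.nbytes always reports the size of the underlying data. A small, self-contained illustration (the shape and dtype are arbitrary, not taken from RLlib):

import sys
import numpy as np

owner = np.zeros((1000, 84, 84), dtype=np.float32)  # array that owns its 28,224,000-byte buffer
view = owner[:]                                      # a view onto the same data, owns nothing

print(owner.nbytes)          # 28224000: size of the actual data
print(sys.getsizeof(owner))  # ~28224128: object header plus the owned data
print(view.nbytes)           # 28224000: what the patched size_bytes() counts
print(sys.getsizeof(view))   # ~128: header only, the shared data is not counted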