mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
parent
f79b826f31
commit
51dbd99a25
2 changed files with 3 additions and 3 deletions
|
@ -452,7 +452,7 @@ def load_checkpoint(
|
|||
with checkpoint.as_directory() as checkpoint_path:
|
||||
estimator_path = os.path.join(checkpoint_path, MODEL_KEY)
|
||||
with open(estimator_path, "rb") as f:
|
||||
estimator_path = cpickle.load(f)
|
||||
estimator = cpickle.load(f)
|
||||
preprocessor = load_preprocessor_from_dir(checkpoint_path)
|
||||
|
||||
return estimator_path, preprocessor
|
||||
return estimator, preprocessor
|
||||
|
|
|
@ -125,7 +125,7 @@ class TensorflowTrainer(DataParallelTrainer):
|
|||
)
|
||||
model.fit(tf_dataset)
|
||||
train.save_checkpoint(
|
||||
epoch=epoch, model_weights=model.get_weights())
|
||||
epoch=epoch, model=model.get_weights())
|
||||
|
||||
train_dataset = ray.data.from_items(
|
||||
[{"x": x, "y": x + 1} for x in range(32)])
|
||||
|
|
Loading…
Add table
Reference in a new issue