Co-authored-by: Eric Liang <ekhliang@gmail.com>
Co-authored-by: matthewdeng <matthew.j.deng@gmail.com>
Co-authored-by: Matthew Deng <matt@anyscale.com>
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Signed-off-by: Amog Kamsetty <amogkamsetty@yahoo.com>
As discussed offline, allow configurability for feature columns and keep columns in BatchPredictor for better scoring UX on test datasets.
Updates TensorflowPredictor to use the new _predict_pandas API.
Also as agreed upon offline, removes the extra configurations from TensorflowPredictor (column selection, concatenation) in favor of having this be done via a Preprocessor.
Update documentation to use `session.report`.
Next steps:
1. Update our internal caller to use `session.report`. Most importantly, CheckpointManager and DataParallelTrainer.
2. Update `get_trial_resources` to use PGF notions to incorporate the requirement of ResourceChangingScheduler. @Yard1
3. After 2 is done, change all `tune.get_trial_resources` to `session.get_trial_resources`
4. [internal implementation] remove special checkpoint handling logic from huggingface trainer. Optimize the flow for checkpoint conversion with `session.report`.
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
1. Update `DummyTrainer` to take `num_epochs` instead of `runtime_seconds`.
1. Ray Train expects equal number of calls to `train.report()`. Different workers may run at different speeds and terminate after different epoch numbers, which causes an error.
2. Add `generate_epochs` to support `DatasetPipeline` when `use_stream_api` is True.
3. Update `__main__` code to support testing different configurations.
This PR:
* Allows the user to set `keep_checkpoints_num` and `checkpoint_score_attr` in `RunConfig` using the `CheckpointStrategy` dataclass
* Adds two new fields to the `Result` object - `best_checkpoints` - a list of saved best checkpoints as determined by `CheckpointingConfig`.
As the integration logging callbacks are commonly used with AIR Trainers, they should be moved from the tune package to the air package. The old imports will still work, but raise a deprecation warning.
Follow up on our last discussion for supporting piecemeal fashion air users.
Only did for tensorflow for now, want to collect some feedback on API naming, package structure etc and I will add others.