mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[GCP] Update GCP TPU config (#18634)
* [autoscaler] Update GCP TPU config * Preemptible by default * Remove libtpu link from head node * Workaround
This commit is contained in:
parent
9c9b482661
commit
573c66a755
1 changed files with 12 additions and 6 deletions
|
@ -32,9 +32,9 @@ available_node_types:
|
|||
# Support for TPU pods will be added in the future.
|
||||
acceleratorType: v2-8
|
||||
runtimeVersion: v2-alpha
|
||||
# Uncomment to use preemptible TPUs
|
||||
# schedulingConfig:
|
||||
# preemptible: true
|
||||
schedulingConfig:
|
||||
# Set to false to use non-preemptible TPUs
|
||||
preemptible: true
|
||||
|
||||
provider:
|
||||
type: gcp
|
||||
|
@ -51,15 +51,21 @@ head_node_type: ray_head_default
|
|||
# Compute instances have python 3.7, but TPUs have 3.8 - need to update
|
||||
# Install Jax and other dependencies on the Compute head node
|
||||
head_setup_commands:
|
||||
- conda create -y -n "ray" python=3.8.5 && sudo update-alternatives --install /opt/conda/bin/python python /opt/conda/envs/ray/bin/python 10 && sudo update-alternatives --install /opt/conda/bin/pip pip /opt/conda/envs/ray/bin/pip 10
|
||||
- export PATH="$PATH:/opt/conda/envs/ray/bin" && echo 'export PATH="$PATH:/opt/conda/envs/ray/bin"' >> ~/.bashrc
|
||||
- python -m pip install --upgrade "jax[cpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
# Two first lines are a workaround for ssh timing out
|
||||
- sleep 2
|
||||
- sleep 2
|
||||
- sudo chown -R $(whoami) /opt/conda/*
|
||||
- conda create -y -n "ray" python=3.8.5
|
||||
- conda activate ray && echo 'conda activate ray' >> ~/.bashrc
|
||||
- python -m pip install --upgrade pip
|
||||
- python -m pip install --upgrade "jax[cpu]==0.2.14"
|
||||
- python -m pip install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku google-api-python-client cryptography tensorboardX ray[default]
|
||||
- python -m pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
|
||||
- git clone https://github.com/Yard1/swarm-jax.git && cd swarm-jax && python -m pip install .
|
||||
|
||||
# Install Jax and other dependencies on TPU
|
||||
worker_setup_commands:
|
||||
- pip3 install --upgrade pip
|
||||
- pip3 install --upgrade "jax[tpu]==0.2.14" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
- pip3 install --upgrade fabric dataclasses optax==0.0.6 git+https://github.com/deepmind/dm-haiku tensorboardX ray[default]
|
||||
- python3 -c "import jax; jax.device_count(); jax.numpy.add(1, 1)" # test if Jax has been installed correctly
|
||||
|
|
Loading…
Add table
Reference in a new issue