Using checkpoints for ML models
Many machine learning (ML) frameworks provide methods to store the state of a trained model through a process called checkpointing. When a model is trained on a data set, it adjusts internal parameters such as weights in order to improve its performance on the provided training samples as well as on future samples in general. Depending on the model and the training procedure, these adjustments may be made for each individual sample or after a small batch of training samples.
Training ML models on GPUs is fast for several reasons; one of them is that the model parameters and the training data can be pre-allocated in the GPU's dedicated memory and remain there while they are used. This reduces the overhead of moving data between the GPU, main memory, and disk, so ideally this process is not interrupted.
However, problems may occur when you train a model for a long time. Depending on the framework, the program could allocate too much memory on the GPU if the memory allocation is allowed to grow. Splitting up a job might sometimes be a workaround in this scenario, although it is recommended to tune the ML hyperparameters, such as the batch size, to work with the GPU in use.
Another potential problem is exceeding the time limit of a SLURM job. Currently, the ALICE cluster has partitions that allow up to 7 days of running time. Longer jobs are automatically terminated to give other users their turn, so progress from training a model is lost if the model state has not been saved to permanent storage. Obtaining another slot in the scheduler may take some time, although the partitions with shorter time limits have priority over long jobs.
This tutorial provides an overview of the known options for checkpointing machine learning models and discusses how to fit this into the time limits of SLURM's scheduler. Some parts may be best suited for more experienced users.
Existing ML models
Many existing models written in TensorFlow or PyTorch, such as models provided with research papers or used in performance tests (like those provided by NVIDIA), often already contain the logic necessary to save and load checkpoints. Checkpoints are often provided for the model as well, so that you can continue training from a known point, or skip training entirely and use the pre-trained model for classification or inference on a (new) data set.
The documentation of the model usually describes how the checkpoints are loaded. In most cases, a separate script can download existing checkpoints and place them in the proper location, and the model automatically detects and reads them. Sometimes you may need to adjust some arguments: add --help to the end of the command used to run the model and look for the arguments related to checkpoints. If that does not work, the developers may provide other documentation describing the necessary steps.
For some models, you can prepend an environment variable like RESUME=true to the command to make the model use existing checkpoints found in the appropriate directory. Again, this depends on the model and on the distribution method for the checkpoints.
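As a hypothetical example (the script name and arguments are placeholders; check the model's documentation for its actual mechanism), such an invocation could look like this:

```bash
RESUME=true python train.py --data_dir ./data
```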
Most existing models automatically save checkpoints in the directory appropriate for that model. If you cannot find them, check the command line arguments again.
TensorFlow
If you are creating your own ML model in TensorFlow, or extending an existing model, you will probably want to add checkpoint writing and loading if you haven't done so already. TensorFlow checkpoints track the values of the model's parameters. They can also be used to inspect the model state with TensorBoard, which helps with investigating the stability of the learning progress.
If you use the Keras API for TensorFlow in your model, then saving a checkpoint is as simple as calling save_weights on your model with a checkpoint name. Loading a checkpoint is similarly done by calling load_weights with an existing checkpoint name on the model object.
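A minimal sketch of these two calls, assuming a TensorFlow 2 installation whose Keras writes checkpoints in the TensorFlow format; the toy model and the checkpoint prefix ckpt/demo are placeholders:

```python
import tensorflow as tf

# A small placeholder model; replace this with your own architecture.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(10,)),
    tf.keras.layers.Dense(64, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

# ... train the model ...
model.save_weights("ckpt/demo")    # write the weights to a checkpoint

# Later, for example in a new job: rebuild the same architecture and restore it.
model.load_weights("ckpt/demo")
```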
These calls only save or load once, however. To save more often with just Keras, you could perform the saving within a train_step method of the model class (by calling self.save_weights("ckpt") somewhere in the training step, or by writing your own training loop) and call save_weights depending on the step count, but this still gives little control over how many checkpoints you want to retain.
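As a rough sketch of the custom-training-loop variant (the model, dataset, and save frequency below are placeholders), a hand-written loop can simply call save_weights every so many steps:

```python
import tensorflow as tf

# Placeholder model, loss and data; replace with your own.
model = tf.keras.Sequential([tf.keras.Input(shape=(10,)),
                             tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adam()
loss_fn = tf.keras.losses.MeanSquaredError()
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((100, 10)), tf.random.normal((100, 1)))).batch(10)

for step, (x, y) in enumerate(dataset.repeat().take(500)):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    if step % 100 == 0:               # save every 100 steps (arbitrary choice)
        model.save_weights("ckpt")
```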
A more thorough option is to use tf.train.Checkpoint to define what the checkpoint should look like, and tf.train.CheckpointManager to perform the saving. This works with both Keras and non-Keras models, and it can also handle restoring from the most recent checkpoint. Additional save parameters can be used to keep track of other information, for example which example from the dataset will be used next for training.
The following sketch is based on the TensorFlow tutorial guide on training checkpoints and demonstrates using the checkpoint manager for regular saving and restoring; the toy model, dataset, checkpoint directory, and save frequency are placeholders, and the SLURM time check discussed later in this tutorial is included as well:
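```python
import os
import subprocess
import time

import tensorflow as tf


class Net(tf.keras.Model):
    """A simple linear model (placeholder)."""

    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)

    def call(self, x):
        return self.l1(x)


def toy_dataset():
    """A tiny synthetic dataset that repeats forever."""
    inputs = tf.range(10.0)[:, None]
    labels = inputs * 5.0 + tf.range(5.0)[None, :]
    return tf.data.Dataset.from_tensor_slices(
        dict(x=inputs, y=labels)).repeat().batch(2)


def train_step(net, example, optimizer):
    """One optimization step; returns the loss."""
    with tf.GradientTape() as tape:
        output = net(example["x"])
        loss = tf.reduce_mean(tf.abs(output - example["y"]))
    gradients = tape.gradient(loss, net.trainable_variables)
    optimizer.apply_gradients(zip(gradients, net.trainable_variables))
    return loss


def slurm_end_time(leeway=60):
    """Estimate when SLURM will end this job, minus some leeway in seconds.

    squeue %L prints the remaining time as [days-]hours:minutes:seconds;
    shorter forms such as minutes:seconds are possible as well. This assumes
    the script runs inside a SLURM job, so SLURM_JOB_ID is set.
    """
    out = subprocess.run(
        ["squeue", "-h", "-j", os.environ["SLURM_JOB_ID"], "-o", "%L"],
        capture_output=True, text=True, check=True).stdout.strip()
    days, _, clock = out.rpartition("-")
    parts = [int(p) for p in clock.split(":")]
    while len(parts) < 3:                      # pad missing hours/minutes
        parts.insert(0, 0)
    hours, minutes, seconds = parts
    left = (int(days or 0) * 24 + hours) * 3600 + minutes * 60 + seconds
    return time.time() + left - leeway


net = Net()
opt = tf.keras.optimizers.Adam(0.1)
iterator = iter(toy_dataset())

# The checkpoint tracks the step counter, optimizer, model and dataset
# iterator; the manager keeps at most three checkpoints in ./tf_ckpts.
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net,
                           iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, "./tf_ckpts", max_to_keep=3)

# Restore the most recent checkpoint if one exists.
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from", manager.latest_checkpoint)
else:
    print("Initializing from scratch.")

end_time = slurm_end_time()
for _ in range(500):
    if time.time() > end_time:                 # stop before SLURM kills the job
        print("Approaching the SLURM time limit, stopping early.")
        break
    loss = train_step(net, next(iterator), opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
        save_path = manager.save()
        print(f"Saved checkpoint for step {int(ckpt.step)}: {save_path}")
        print(f"loss {loss.numpy():1.2f}")
```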
For older models from TensorFlow v1, such as tf.estimator models, you may be able to provide a RunConfig to the model in order to specify checkpoint options such as the frequency of checkpointing and the number of checkpoints to keep; see the old TensorFlow guides for details. Otherwise, you may need to adjust the model to use checkpoint/saver hooks. At that point, it may be worthwhile to convert the model to TensorFlow v2 instead. A migration guide for checkpoint saving from TensorFlow v1 to v2 can be found at Migrate checkpoint saving for TensorFlow Core.
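As a hedged illustration (assuming a TensorFlow version that still ships tf.estimator; my_model_fn, the directory, and the numbers are hypothetical placeholders), such a configuration could look like this:

```python
import tensorflow as tf

# Placeholder configuration for a tf.estimator model: checkpoint every
# 1000 steps and keep at most 5 checkpoints in ./estimator_ckpts.
run_config = tf.estimator.RunConfig(
    model_dir="./estimator_ckpts",
    save_checkpoints_steps=1000,
    keep_checkpoint_max=5,
)
# estimator = tf.estimator.Estimator(model_fn=my_model_fn, config=run_config)
```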
PyTorch
Like TensorFlow, PyTorch has methods for saving checkpoints. It does not currently have a class like the checkpoint manager that combines describing the contents of a checkpoint and tracking how many checkpoints to keep in one place. There are several tutorials that may help with implementing checkpoints in PyTorch models in different ways:
These tutorials do not provide a complete overview of how to save checkpoints during a training loop or how to load the most recent checkpoint only if it exists. However, ideas can be taken from them for keeping track of the training epochs and for making the proper calls to torch.save and torch.load during and before the training, by implementing a training loop that optimizes the model parameters.
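A minimal sketch of this pattern, assuming a toy model and optimizer; the checkpoint file name, the number of epochs, and the training code itself are placeholders:

```python
import os

import torch
import torch.nn as nn

# Placeholder model and optimizer; replace with your own.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
ckpt_path = "checkpoint.pt"
start_epoch = 0

# Resume from the most recent checkpoint only if one exists.
if os.path.exists(ckpt_path):
    state = torch.load(ckpt_path)
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    start_epoch = state["epoch"] + 1

for epoch in range(start_epoch, 100):
    # ... run one epoch of training here ...

    # Save the epoch counter together with the model and optimizer state,
    # so a later job knows where to continue.
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, ckpt_path)
```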
SLURM
As mentioned in the introduction, one reason to use checkpoints is to split up a long-running training run into several jobs that each fit within the time limits of the SLURM scheduler's queues. Each job can then use the most recent checkpoint to continue training from the state the model was in at the end of its previous job, or, if that job failed during training (due to the time limit, for example), from an earlier checkpoint that was kept.
It is possible to automate the repetitive scheduling of jobs. Note: One of the rules for ALICE and SHARK is to only schedule jobs from the login node; do not schedule jobs recursively from a SLURM script (see Best Practices)!
First, however, we discuss how this splitting up can be done. Usually, you can provide an argument to your existing ML model program to limit the number of epochs it performs. But how do we know how many epochs fit within the time limit imposed by SLURM? One way to find out is trial and error: run the job for a large number of epochs, saving checkpoints throughout. Once it hits the time limit, check the logs to find how many epochs were trained, and limit subsequent jobs to that number of epochs. Hopefully the training speed is not too variable.
Another option is to run one epoch, see how much time the job took, and divide the time limit by this elapsed time. If properly calculated and rounded down, this should give a good estimate of the number of epochs that fit within the time limit.
The final option (which is probably the easiest) is to run as many epochs as you want, but to check from time to time whether the script's run time is approaching the SLURM time limit. You can then stop training early to ensure the checkpoints are saved and any other "tear down" steps are run.
Next, ensure your program can save and load checkpoints, preferably still saving them after each batch of epochs completes. The program should also resume from the most recent checkpoint. You need to store or copy the checkpoints to a persistent storage location, such as /data1, so that they can be reused between jobs. Finally, do not start multiple jobs on the same checkpoints concurrently, because you may lose track of which checkpoint belongs to which (diverging) training run.
An additional step could be to automate the above within one job. Use one of the options mentioned above, preferably one where you can be fairly certain that the epochs will finish before the job is stopped by the scheduler, so that any further lines in your batch script (for example to perform cleanup or to copy results) can still run. You can then even keep the checkpoints in the local scratch space created for the job and copy them back at the end of the script, instead of sending them to network storage during the run. Afterwards, you can determine whether the job finished and start a new one, first copying over the existing checkpoints to resume from.
In the TensorFlow example above, we already implement some of this automation. The third option is used here: we request the time limit from SLURM (with the command squeue -h -j $SLURM_JOB_ID -o %L) and calculate the expected end time based on that limit. We then compare the current time to the projected end time, with 60 seconds of leeway.
The example job should now be able to run over and over with an appropriate SLURM batch script. To try it out, save the Python code above in a file named checkpoint_demo.py and the batch script below in checkpoint_demo.slurm:
ALICE
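A minimal sketch of what checkpoint_demo.slurm could look like on ALICE; the partition, module name, resource requests, and storage path are placeholders that need to be adapted to your own setup:

```bash
#!/bin/bash
#SBATCH --job-name=checkpoint_demo
#SBATCH --partition=gpu-short        # placeholder: choose a partition that fits your run time
#SBATCH --gres=gpu:1
#SBATCH --time=00:30:00
#SBATCH --output=%x_%j.out

# Placeholder module name: check "module avail" for the exact TensorFlow version.
module load TensorFlow

# Keep the checkpoints on persistent storage (for example under /data1) so
# that the next job can resume from them; this path is a placeholder.
CKPT_DIR=/data1/$USER/checkpoint_demo
mkdir -p "$CKPT_DIR"
cd "$CKPT_DIR"

python "$SLURM_SUBMIT_DIR/checkpoint_demo.py"
```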
SHARK
On SHARK, TensorFlow is not available as a module. You will have to install it yourself, either through pip (in a Python virtual environment) or in a conda environment. Using a container will not work here, because the checkpointing needs access to SLURM commands such as squeue.
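On SHARK, a sketch of checkpoint_demo.slurm could therefore activate your own environment instead of loading a module; the partition, environment path, and storage path are again placeholders:

```bash
#!/bin/bash
#SBATCH --job-name=checkpoint_demo
#SBATCH --partition=gpu              # placeholder: choose an appropriate SHARK partition
#SBATCH --gres=gpu:1
#SBATCH --time=00:30:00
#SBATCH --output=%x_%j.out

# Placeholder: activate a Python environment in which TensorFlow is installed,
# for example a virtual environment or a conda environment.
source "$HOME/envs/tensorflow/bin/activate"
# or: conda activate tensorflow

# Placeholder path on persistent storage for the checkpoints.
CKPT_DIR=/path/to/persistent/storage/checkpoint_demo
mkdir -p "$CKPT_DIR"
cd "$CKPT_DIR"

python "$SLURM_SUBMIT_DIR/checkpoint_demo.py"
```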