How to checkpoint and resume training in PyTorch

Why do I need to checkpoint?

If you're training a deep learning model, you might have run into the issue of having to restart training from scratch because of a crash or some other failure. This can be a huge waste of time and compute, especially on a large dataset or a complex model. Checkpointing trades a little disk space for that time: you periodically save the training state to disk, so if something goes wrong you can resume from the last checkpoint instead of starting over.

What should I checkpoint?

To completely resume training, you need to save the following:

  • Model weights
  • Optimizer state
  • Training step/epoch

Using these three pieces of information, you can completely resume training from where you left off.

Checkpointing and resuming training

Vanilla PyTorch

import os

import torch
from safetensors.torch import save_file

...

def save_checkpoint(model, optimizer, step, filename):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step': step
    }

    ## Option 1: torch.save -- it's the same as pickling, so the whole nested
    ## dict (weights, optimizer state, and step) goes into one file
    torch.save(checkpoint, f"{filename}.pt")

    ## Option 2: safetensors -- more interoperable and generally recommended,
    ## but it only accepts flat dicts of tensors, so save the model weights
    ## with safetensors and keep torch.save for the optimizer state and step
    os.makedirs(f"{filename}/{step}", exist_ok=True)
    save_file(model.state_dict(), f"{filename}/{step}/model.safetensors")
    torch.save(
        {'optimizer_state_dict': optimizer.state_dict(), 'step': step},
        f"{filename}/{step}/training_state.pt",
    )

num_epochs = 10
step = 0

for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        ## Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        ## Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step += 1

        ## Checkpoint every 100 steps
        if step % 100 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")
            checkpoint_dir = "checkpoints"
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_filename = os.path.join(checkpoint_dir, f"checkpoint_{step}")
            save_checkpoint(model, optimizer, step, checkpoint_filename)
...
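
To actually resume, load the saved state back into the model and optimizer before the training loop and continue counting from the stored step. Below is a minimal sketch assuming the torch.save format written by save_checkpoint above; load_checkpoint is a hypothetical helper, and the discovery logic just picks the highest step number found on disk.

def load_checkpoint(model, optimizer, path):
    ## torch.load unpickles the dict written by torch.save above
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['step']

## Resume from the most recent checkpoint if one exists, otherwise start at step 0
checkpoint_dir = "checkpoints"
step = 0
if os.path.isdir(checkpoint_dir):
    checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith(".pt")]
    if checkpoints:
        ## Filenames look like checkpoint_<step>.pt, so pick the highest step
        latest = max(checkpoints, key=lambda f: int(f[:-3].split("_")[-1]))
        step = load_checkpoint(model, optimizer, os.path.join(checkpoint_dir, latest))
        print(f"Resuming from step {step}")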

Notes

  • safetensors is generally recommended for speed and safety: unlike torch.save, loading it never unpickles arbitrary code. It only stores flat dicts of tensors, though, so non-tensor state (the step counter, optimizer hyperparameters) still needs torch.save or a similar format; see the loading sketch below.
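
If you went the safetensors route, loading the weights back looks like this. A minimal sketch assuming the same model and optimizer objects from the training loop and the directory layout save_checkpoint writes (the concrete path here is illustrative); safetensors.torch.load_file returns a plain dict of tensors that you feed to load_state_dict.

from safetensors.torch import load_file

## Load the flat tensor dict and copy it into the model's parameters
state_dict = load_file("checkpoints/checkpoint_100/100/model.safetensors")
model.load_state_dict(state_dict)

## The optimizer state and step were saved with torch.save, so load them separately
training_state = torch.load("checkpoints/checkpoint_100/100/training_state.pt")
optimizer.load_state_dict(training_state['optimizer_state_dict'])
step = training_state['step']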

Using transformers.Trainer

from transformers import Trainer, TrainingArguments
from pathlib import Path

...

logdir = Path("logs")

training_args = TrainingArguments(
    output_dir=str(logdir),
    seed=random_state,
)

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=tr_dset,
    eval_dataset=va_dset,
    data_collator=collate_fn,
)

can_continue = any(logdir.glob("checkpoint-*"))
if can_continue:
    print(f"Resuming from the latest checkpoint in {logdir.resolve()}")
else:
    print(f"No checkpoint in {logdir.resolve()}, starting fresh")
trainer.train(resume_from_checkpoint=can_continue)

## use the model...
trainer.model

...
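
By default the Trainer writes checkpoint-<step> directories into output_dir on its own schedule. If you want explicit control over how often it saves and how many checkpoints it keeps around, TrainingArguments exposes save_strategy, save_steps, and save_total_limit; a sketch with arbitrary values:

training_args = TrainingArguments(
    output_dir=str(logdir),
    seed=random_state,
    save_strategy="steps",   ## checkpoint on a step interval instead of every epoch
    save_steps=500,          ## illustrative value: write checkpoint-<step> every 500 steps
    save_total_limit=2,      ## keep only the two most recent checkpoints on disk
)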