How to checkpoint and resume training in PyTorch
Why do I need to checkpoint?
If you're training a deep learning model, you may have run into the problem of having to restart training from scratch because of a crash or some other failure. That can be a huge waste of time and compute, especially on a large dataset or a complex model. By checkpointing your model during training, you trade a bit of disk space for the ability to resume from the last checkpoint instead of starting over.
What should I checkpoint?
To completely resume training, you need to save the following:
- Model weights
- Optimizer state
- Training step/epoch
With these three pieces of information, you can resume training from exactly where you left off. If you also use a learning-rate scheduler or need bit-exact reproducibility, save its state_dict and the random number generator states as well.
Checkpointing and resuming training
Vanilla PyTorch
import os
import torch
from safetensors.torch import save_file

...

def save_checkpoint(model, optimizer, step, filename):
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'step': step,
    }
    # Option 1: torch.save -- it's the same as pickling, so we can save the whole dict at once
    torch.save(checkpoint, f"{filename}.pt")
    # Option 2: safetensors for the weights -- more interoperable, generally recommended.
    # Note that safetensors only stores flat dicts of tensors, so the optimizer state
    # (which contains non-tensor metadata) is still saved with torch.save here.
    os.makedirs(f"{filename}/{step}", exist_ok=True)
    save_file(model.state_dict(), f"{filename}/{step}/model.safetensors")
    torch.save(optimizer.state_dict(), f"{filename}/{step}/optimizer.pt")
num_epochs = 10
step = 0

for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        step += 1

        # Checkpoint every 100 steps
        if step % 100 == 0:
            print(f"Step {step}, Loss: {loss.item():.4f}")
            checkpoint_dir = "checkpoints"
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_filename = os.path.join(checkpoint_dir, f"checkpoint_{step}")
            save_checkpoint(model, optimizer, step, checkpoint_filename)
...
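To actually resume, load the checkpoint back and restore each piece before re-entering the loop. Here is a minimal sketch, assuming the torch.save format from save_checkpoint above (the load_checkpoint helper and the example path are illustrative, not part of the code above):

def load_checkpoint(model, optimizer, filename, device="cpu"):
    # Load the pickled dict written by torch.save above
    checkpoint = torch.load(f"{filename}.pt", map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Return the step so the training loop can keep counting from here
    return checkpoint['step']

# e.g. resume from the checkpoint written at step 100
step = load_checkpoint(model, optimizer, "checkpoints/checkpoint_100")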
Notes
safetensors is generally recommended for speed and safety: unlike pickle-based .pt files, safetensors files cannot execute arbitrary code when loaded.
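To load weights back from a safetensors file, use load_file and pass the result to load_state_dict. A small sketch; the path follows the directory layout used in save_checkpoint above and is illustrative:

from safetensors.torch import load_file

# Returns a flat dict of tensors, ready for load_state_dict
state_dict = load_file("checkpoints/checkpoint_100/100/model.safetensors")
model.load_state_dict(state_dict)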
transformers.Trainer
With the Hugging Face Trainer, checkpoints are written to output_dir automatically during training, so resuming is mostly a matter of passing resume_from_checkpoint to trainer.train:

from transformers import Trainer, TrainingArguments
from pathlib import Path

...

logdir = Path("logs")

training_args = TrainingArguments(
    output_dir=str(logdir),
    seed=random_state,
)

trainer = Trainer(
    model_init=model_init,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=tr_dset,
    eval_dataset=va_dset,
    data_collator=collate_fn,
)

# Resume only if a checkpoint already exists in the output directory
can_continue = any(logdir.glob("checkpoint-*"))
if can_continue:
    print(f"Picking up from checkpoint in {logdir.resolve()}")
else:
    print(f"No checkpoint at {logdir.resolve()}, starting fresh")

trainer.train(resume_from_checkpoint=can_continue)

# use the model...
trainer.model
...
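How often the Trainer writes those checkpoint-* directories into output_dir is controlled through TrainingArguments. A sketch with illustrative values (the defaults already save periodically during training):

training_args = TrainingArguments(
    output_dir=str(logdir),
    seed=random_state,
    save_strategy="steps",   # write a checkpoint every save_steps optimizer steps
    save_steps=500,
    save_total_limit=2,      # keep only the two most recent checkpoints on disk
)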