Managing project configuration with hydra
What is Hydra?
Hydra is a library built on top of OmegaConf that uses .yaml (or .yml) files to pass arguments into a Python script.
Why should I use it?
There are already nice libraries like click or argparse that can be used to pass arguments. However, they require a lot of manual work once your project's configuration starts getting heavy. You could also just use a .yaml file and load it into your script, but then you would have to write a lot of boilerplate code to handle the configuration, especially if you want to nest modular configurations. As always, why reinvent the wheel?
How do I use it?
It's similar to click. You just add a decorator to your entry point method and specify the configuration it should use.
Simple Example
Here is a very simple example.
# conf/main.yaml
defaults:
  - dataset: mnist
  - model: mlp
  - train: default

# conf/dataset/mnist.yaml
name: mnist
path: data/mnist
batch_size: 32
num_workers: 4
shuffle: true
pin_memory: true
drop_last: false

# conf/model/mlp.yaml
name: mlp
num_hidden: 12
hidden_dim: 128

# conf/train/default.yaml
learning_rate: 0.001
optimizer: Adam
weight_decay: 0.0
# main.py
import hydra

@hydra.main(config_path="conf", config_name="main")
def entry_point(config):
    load_data(config.dataset)
    load_model(config.model)
    train_model(config.train)
    ...

if __name__ == "__main__":
    entry_point()
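If you want to sanity-check how the defaults are composed, you can dump the merged configuration from the entry point. This is a minimal sketch, assuming the same conf layout as above; it only uses OmegaConf's to_yaml helper.

# main.py (sketch: print the composed configuration)
import hydra
from omegaconf import OmegaConf

@hydra.main(config_path="conf", config_name="main")
def entry_point(config):
    # shows the merged dataset/model/train groups as YAML
    print(OmegaConf.to_yaml(config))

if __name__ == "__main__":
    entry_point()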
Using class instantiation
The previous example is nice, but there isn't much of a difference from using click or argparse. hydra shines when you start using class instantiation.
# config.py
from dataclasses import dataclass

@dataclass
class DatasetConfig:
    name: str = "mnist"
    path: str = "data/mnist"
    batch_size: int = 32
    num_workers: int = 4
    shuffle: bool = True
    pin_memory: bool = True
    drop_last: bool = False

@dataclass
class ModelConfig:
    name: str = "mlp"
    num_hidden: int = 12
    hidden_dim: int = 128

@dataclass
class TrainConfig:
    learning_rate: float = 0.001
    optimizer: str = "Adam"
    weight_decay: float = 0.0
Now that you have your Python classes, you can instantiate the config objects by modifying the .yaml files a bit.
# conf/main.yaml
# same as before!
defaults:
  - dataset: mnist
  - model: mlp
  - train: default

# conf/dataset/mnist.yaml
_target_: config.DatasetConfig
name: mnist
path: data/mnist
batch_size: 32
num_workers: 4
shuffle: true
pin_memory: true
drop_last: false

# conf/model/mlp.yaml
_target_: config.ModelConfig
name: mlp
num_hidden: 12
hidden_dim: 128

# conf/train/default.yaml
_target_: config.TrainConfig
learning_rate: 0.001
optimizer: Adam
weight_decay: 0.0
This assumes your config classes are reachable as config.[CONFIG_NAME], i.e. they are importable from the config module (config.py).
Now you can just instantiate the config objects in your entry point method.
import hydra

@hydra.main(config_path="conf", config_name="main")
def entry_point(config):
    # the type signatures are just for clarity
    dataset_config: DatasetConfig = hydra.utils.instantiate(config.dataset)
    model_config: ModelConfig = hydra.utils.instantiate(config.model)
    train_config: TrainConfig = hydra.utils.instantiate(config.train)
    load_data(dataset_config)
    load_model(model_config)
    train_model(train_config)
In some cases, you will notice that the dictionary values are not actually Python dictionaries but OmegaConf objects. You can control how these containers are converted by using the _convert_ option in the hydra.utils.instantiate call.
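For example, passing _convert_="all" asks Hydra to convert nested OmegaConf containers into plain Python dicts and lists before they are handed to your class. A minimal sketch, reusing the names from the example above:

# sketch: force nested containers to be plain Python objects
dataset_config: DatasetConfig = hydra.utils.instantiate(
    config.dataset,
    _convert_="all",
)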
Instantiating with additional arguments
If you are a careful reader, you may have noticed that we are still using "Adam" as a string and not the actual object. This still requires us to manually instantiate the optimizer in our code, which can be cumbersome. We can handle this within hydra as well, by using the hydra.utils.get_class helper to resolve the dotted path into the class itself, i.e. a callable we can invoke later.
# conf/train/default.yaml
_target_: config.TrainConfig
optimizer:
  _target_: hydra.utils.get_class
  path: torch.optim.Adam
opt_args:
  lr: 0.001
  weight_decay: 0.0
Given that you already have a model and a raw config object, you can instantiate the optimizer as follows.
from dataclasses import dataclass, field
from typing import Callable

import hydra
import torch

@dataclass
class TrainConfig:
    learning_rate: float = 0.001
    # now optimizer is a callable!
    optimizer: Callable = torch.optim.Adam
    # we can keep them separated as before, but this makes life easier.
    # in fact, the optimizer itself can be another nested config under something like conf/train/opt/adam.
    opt_args: dict = field(default_factory=dict)

train_config = hydra.utils.instantiate(config.train)
train_config.optimizer(model.parameters(), **train_config.opt_args)
You can achieve a similar effect by using the _partial_ option when instantiating the object. Let's say that you just want the optimizer; you can then have something like this:
# conf/train/default.yaml
_target_: config.TrainConfig
optimizer:
  _target_: torch.optim.Adam
opt_args:
  lr: 0.001
  weight_decay: 0.0
import hydra

optimizer = hydra.utils.instantiate(
    config.train.optimizer,
    _partial_=True,
)(
    model.parameters(),
    **config.train.opt_args,
)
In this case, the hydra.utils.instantiate call returns a partial constructor, similar to what you would get from functools.partial.
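For reference, this is roughly what that partial corresponds to in plain Python. A sketch, assuming torch.optim.Adam, the opt_args shown above, and an existing model module:

import functools
import torch

# conceptually what instantiate(..., _partial_=True) hands back
make_optimizer = functools.partial(torch.optim.Adam, lr=0.001, weight_decay=0.0)
optimizer = make_optimizer(model.parameters())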
Useful tips
Overriding configuration
By file
You can override the main configuration by passing overrides on the command line, e.g.
python main.py dataset=another_dataset model=another_model
Assuming you have dataset/another_dataset.yaml and model/another_model.yaml in your conf directory.
By parameter
If you want to change a single parameter in a configuration file, you can use dot notation (.) to address the exact parameter, e.g.
python main.py dataset.batch_size=512
Assuming the selected dataset config (e.g. conf/dataset/mnist.yaml) has a field called batch_size.
Using arbitrary number of configurations
You can also use an arbitrary number of configurations by modifying the structure a bit.
# conf/main.yaml
defaults:
  - dataset:
      - mnist
      - another_dataset
  # ...

# conf/dataset/mnist.yaml
Mnist:
  _target_: config.DatasetConfig
  name: mnist
  path: data/mnist
  batch_size: 32
  num_workers: 4
  shuffle: true
  pin_memory: true
  drop_last: false

# conf/dataset/another_dataset.yaml
AnotherDataset:
  _target_: config.DatasetConfig
  name: Another-Dataset
  path: data/another_dataset
  batch_size: 64
  num_workers: 4
  shuffle: true
  pin_memory: true
  drop_last: false
Notice that the dataset entry in the defaults list of conf/main.yaml is now itself a list of configuration filenames, and each dataset configuration is now a nested dictionary whose top-level key acts as an identifier (name). Using this structure, you can access these arbitrary configurations through config.dataset.values(), or through config.dataset[DATASET_NAME] if you know the name.