
Managing project configuration with Hydra

What is Hydra?

Hydra is a configuration library built on top of OmegaConf that uses .yaml (or .yml) files to pass arguments into a Python script.

Why should I use it?

There are already nice libraries like click and argparse for passing arguments. However, they require a lot of manual work once your project's configuration gets large. You could also load a .yaml file directly in your script, but then you would have to write a lot of boilerplate code to handle the configuration, especially if you want to nest modular configurations. As always, why reinvent the wheel?
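
To make the "manual work" concrete, here is a rough sketch of a hand-rolled argparse setup for just two groups of settings. The flag names mirror the example parameters used later in this post; the build_parser and to_nested helpers are illustrative assumptions, not part of any library.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Every parameter needs its own flag, type, and default, declared by hand.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-name", type=str, default="mnist")
    parser.add_argument("--dataset-batch-size", type=int, default=32)
    parser.add_argument("--model-name", type=str, default="mlp")
    parser.add_argument("--model-hidden-dim", type=int, default=128)
    return parser


def to_nested(args: argparse.Namespace) -> dict:
    # argparse returns a flat namespace, so regrouping is also manual.
    nested: dict = {}
    for key, value in vars(args).items():
        group, _, field_name = key.partition("_")
        nested.setdefault(group, {})[field_name] = value
    return nested


args = build_parser().parse_args(["--dataset-batch-size", "64"])
config = to_nested(args)
print(config)
```

Each new parameter means another add_argument line, and swapping a whole group (say, a different dataset) has no natural home in this layout.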

How do I use it?

It's similar to click: you add a decorator to your entry point function and specify the configuration it should use.

Simple Example

Here is a very simple example.

# conf/main.yaml
defaults:
  - dataset: mnist
  - model: mlp
  - train: default

# conf/dataset/mnist.yaml
name: mnist
path: data/mnist
batch_size: 32
num_workers: 4
shuffle: true
pin_memory: true
drop_last: false

# conf/model/mlp.yaml
name: mlp
num_hidden: 12
hidden_dim: 128

# conf/train/default.yaml
learning_rate: 0.001
optimizer: Adam
weight_decay: 0.0

# main.py
import hydra

@hydra.main(config_path="conf", config_name="main")
def entry_point(config):
    # load_data, load_model and train_model stand in for your own functions
    load_data(config.dataset)
    load_model(config.model)
    train_model(config.train)
    ...

if __name__ == "__main__":
    entry_point()
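
To get a feel for what the defaults list produces, the sketch below emulates the composition step with plain Python objects. It is a simplified stand-in, not Hydra's implementation: the GROUPS mapping plays the role of the files under conf/, and compose_sketch is a hypothetical helper.

```python
from types import SimpleNamespace

# Stand-in for the YAML files under conf/ (group -> option -> contents).
GROUPS = {
    "dataset": {"mnist": {"name": "mnist", "path": "data/mnist", "batch_size": 32}},
    "model": {"mlp": {"name": "mlp", "num_hidden": 12, "hidden_dim": 128}},
    "train": {"default": {"learning_rate": 0.001, "optimizer": "Adam"}},
}

# Same shape as the defaults list in the main config file.
DEFAULTS = [{"dataset": "mnist"}, {"model": "mlp"}, {"train": "default"}]


def compose_sketch(defaults: list) -> SimpleNamespace:
    """Place each selected option under a key named after its group."""
    composed = {}
    for entry in defaults:
        for group, option in entry.items():
            composed[group] = SimpleNamespace(**GROUPS[group][option])
    return SimpleNamespace(**composed)


config = compose_sketch(DEFAULTS)
print(config.dataset.batch_size, config.model.hidden_dim, config.train.optimizer)
```

In the real library the object passed to the entry point is an OmegaConf DictConfig, which supports the same attribute-style access.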

Using class instantiation

The previous example is nice, but it isn't much different from using click or argparse. Hydra really shines when you start using class instantiation.

# config.py
from dataclasses import dataclass
@dataclass
class DatasetConfig:
    name: str = "mnist"
    path: str = "data/mnist"
    batch_size: int = 32
    num_workers: int = 4
    shuffle: bool = True
    pin_memory: bool = True
    drop_last: bool = False

@dataclass
class ModelConfig:
    name: str = "mlp"
    num_hidden: int = 12
    hidden_dim: int = 128

@dataclass
class TrainConfig:
    learning_rate: float = 0.001
    optimizer: str = "Adam"
    weight_decay: float = 0.0

Now that you have your Python classes, you can instantiate the config objects by adding a _target_ key to the .yaml files.

# conf/main.yaml
# same as before!
defaults:
  - dataset: mnist
  - model: mlp
  - train: default

# conf/dataset/mnist.yaml
_target_: config.DatasetConfig
name: mnist
path: data/mnist
batch_size: 32
num_workers: 4
shuffle: true
pin_memory: true
drop_last: false

# conf/model/mlp.yaml
_target_: config.ModelConfig
name: mlp
num_hidden: 12
hidden_dim: 128

# conf/train/default.yaml
_target_: config.TrainConfig
learning_rate: 0.001
optimizer: Adam
weight_decay: 0.0

This assumes your config classes are importable as config.[CLASS_NAME], i.e. they live in a config.py that is on the Python path.

Now you can instantiate the config objects in your entry point function.

import hydra

from config import DatasetConfig, ModelConfig, TrainConfig

@hydra.main(config_path="conf", config_name="main")
def entry_point(config):
    # the type signatures are just for clarity
    dataset_config: DatasetConfig = hydra.utils.instantiate(config.dataset)
    model_config: ModelConfig = hydra.utils.instantiate(config.model)
    train_config: TrainConfig = hydra.utils.instantiate(config.train)
    load_data(dataset_config)
    load_model(model_config)
    train_model(train_config)
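
Conceptually, _target_ is a dotted import path and the remaining keys become keyword arguments. The following is a minimal emulation of that idea using only the standard library. instantiate_sketch is a hypothetical helper that ignores features such as recursion and positional arguments, and datetime.timedelta is used as the target only so the snippet runs without any project code.

```python
import importlib
from datetime import timedelta


def instantiate_sketch(node: dict):
    """Import the object named by `_target_` and call it with the other keys."""
    kwargs = {k: v for k, v in node.items() if k != "_target_"}
    module_path, _, attr_name = node["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_path), attr_name)
    return target(**kwargs)


# Equivalent in spirit to a YAML file containing:
#   _target_: datetime.timedelta
#   hours: 2
#   minutes: 30
duration = instantiate_sketch(
    {"_target_": "datetime.timedelta", "hours": 2, "minutes": 30}
)
print(duration)
```

This is also why the import path matters: if the module named in _target_ cannot be imported, there is nothing to call.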

In some cases, you will notice that dictionary and list values are not plain Python containers but OmegaConf objects (DictConfig and ListConfig). You can control this with the _convert_ option of the hydra.utils.instantiate call; for example, _convert_="all" converts them to plain dicts and lists.

Instantiating with additional arguments

If you are a careful reader, you may have noticed that we are still passing "Adam" as a string and not the actual class. This still requires us to manually look up the optimizer in our code, which can be cumbersome. We can handle this within Hydra as well, by using hydra.utils.get_class as the target: it resolves a dotted path to the class itself (a callable) rather than to an instance.

# conf/train/default.yaml
_target_: config.TrainConfig
optimizer:
  _target_: hydra.utils.get_class
  path: torch.optim.Adam
opt_args:
  lr: 0.001
  weight_decay: 0.0

Given that you already have a model and the raw config object, you can build the optimizer as follows.

from dataclasses import dataclass, field
from typing import Callable

import hydra
from torch import optim

@dataclass
class TrainConfig:
    # now optimizer is a callable!
    optimizer: Callable = optim.Adam
    # we can have them separated as before, but this makes life easier.
    # in fact, now the optimizer itself can be another nested config under something like conf/train/opt/adam.
    # a mutable default such as {} has to go through default_factory
    opt_args: dict = field(default_factory=dict)

# inside your entry point, where config and model are available
train_config = hydra.utils.instantiate(config.train)
optimizer = train_config.optimizer(model.parameters(), **train_config.opt_args)
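
The pattern of storing a callable together with its keyword arguments can be tried without PyTorch or Hydra. In the sketch below, FakeOptimizer is a made-up stand-in for an optimizer class and params stands in for model.parameters(). The dataclass uses field(default_factory=dict) because dataclasses reject a bare {} as a default value.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


class FakeOptimizer:
    """Hypothetical stand-in for an optimizer class such as Adam."""

    def __init__(self, params, lr: float = 0.001, weight_decay: float = 0.0):
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


@dataclass
class TrainConfig:
    # The class itself is stored, not an instance of it.
    optimizer: Callable[..., Any] = FakeOptimizer
    # Mutable defaults must go through default_factory.
    opt_args: dict = field(default_factory=dict)


train_config = TrainConfig(opt_args={"lr": 0.01, "weight_decay": 0.1})
params = [1.0, 2.0, 3.0]  # placeholder for model.parameters()
optimizer = train_config.optimizer(params, **train_config.opt_args)
print(type(optimizer).__name__, optimizer.lr, optimizer.weight_decay)
```

Note that the keys in opt_args have to match the constructor's parameter names exactly, since they are unpacked as keyword arguments.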

You can achieve a similar effect by using the _partial_ option when instantiating the object. Let's say you just want the optimizer; you can then have something like this:

# conf/train/default.yaml
_target_: config.TrainConfig
optimizer:
  _target_: torch.optim.Adam
opt_args:
  lr: 0.001
  weight_decay: 0.0

# inside your entry point, where config and model are available
import hydra

optimizer = hydra.utils.instantiate(
    config.train.optimizer,
    _partial_=True
)(
    model.parameters(),
    **config.train.opt_args
)

In this case, the hydra.utils.instantiate call returns a partially applied constructor, similar to what you would get from functools.partial.
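
For readers less familiar with functools.partial, this is the behaviour being compared to: some arguments are bound now and the rest are supplied at call time. FakeOptimizer is a hypothetical stand-in for an optimizer class, so the snippet runs without PyTorch or Hydra.

```python
from functools import partial


class FakeOptimizer:
    """Hypothetical stand-in for an optimizer class such as Adam."""

    def __init__(self, params, lr: float = 0.001, weight_decay: float = 0.0):
        self.params = list(params)
        self.lr = lr
        self.weight_decay = weight_decay


# Bind the keyword arguments now; the parameters are not known yet.
make_optimizer = partial(FakeOptimizer, lr=0.01, weight_decay=0.1)

# Later, once the model exists, supply the remaining positional argument.
params = [1.0, 2.0]  # placeholder for model.parameters()
optimizer = make_optimizer(params)
print(optimizer.lr, optimizer.weight_decay, len(optimizer.params))
```

This is handy for optimizers in particular, because the model's parameters usually only exist after the configuration has been loaded.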

Useful tips

Overriding configuration

By file

You can override the main configuration by passing arguments on the command line, e.g.

python main.py dataset=another_dataset model=another_model

This assumes you have dataset/another_dataset.yaml and model/another_model.yaml in your conf directory.

By parameter

If you want to change a single parameter in a configuration file, use . instead of / to address that exact parameter, e.g.

python main.py dataset.batch_size=512

This assumes the selected dataset configuration (e.g. dataset/mnist.yaml) has a field called batch_size.
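
The dotted form can be read as a path into the nested configuration. A simplified emulation follows. apply_override is a hypothetical helper, and its value parsing (via ast.literal_eval) is an assumption made for the sketch, not Hydra's actual override grammar.

```python
import ast
import copy


def apply_override(config: dict, override: str) -> dict:
    """Return a copy of `config` with one `a.b.c=value` override applied."""
    dotted_key, _, raw_value = override.partition("=")
    try:
        value = ast.literal_eval(raw_value)  # numbers, lists, quoted strings
    except (ValueError, SyntaxError):
        value = raw_value  # fall back to a plain string
    result = copy.deepcopy(config)
    node = result
    *parents, leaf = dotted_key.split(".")
    for key in parents:
        node = node[key]  # raises KeyError if the path does not exist
    node[leaf] = value
    return result


base = {"dataset": {"name": "mnist", "batch_size": 32}}
updated = apply_override(base, "dataset.batch_size=512")
print(updated)
```

The original dictionary is left untouched, which mirrors the idea that overrides apply to a single run rather than to the files on disk.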

Using arbitrary number of configurations

You can also use an arbitrary number of configurations by modifying the structure a bit.

# conf/main.yaml
defaults:
  - dataset: 
    - mnist
    - another_dataset
    # ...  

# conf/dataset/mnist.yaml
Mnist:
  _target_: config.DatasetConfig
  name: mnist
  path: data/mnist
  batch_size: 32
  num_workers: 4
  shuffle: true
  pin_memory: true
  drop_last: false

# conf/dataset/another_dataset.yaml
AnotherDataset:
  _target_: config.DatasetConfig
  name: Another-Dataset
  path: data/another_dataset
  batch_size: 64
  num_workers: 4
  shuffle: true
  pin_memory: true
  drop_last: false

Notice that conf/main.yaml now lists several configuration names for the dataset group, and each dataset configuration is now a nested dictionary whose top-level key is an identifier (name). Using this structure, you can access these configurations through config.dataset.values(), or through config.dataset[DATASET_NAME] if you know the name.
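
With that layout, the composed dataset node behaves like a mapping from identifier to configuration. The sketch below uses a plain dictionary in place of the composed config (an assumption, to keep it runnable without Hydra) to show both access patterns.

```python
# Plain-dict stand-in for the composed `config.dataset` node.
dataset_node = {
    "Mnist": {"name": "mnist", "path": "data/mnist", "batch_size": 32},
    "AnotherDataset": {
        "name": "Another-Dataset",
        "path": "data/another_dataset",
        "batch_size": 64,
    },
}

# Iterate over every selected dataset without knowing the identifiers.
batch_sizes = {cfg["name"]: cfg["batch_size"] for cfg in dataset_node.values()}

# Or look one up directly when the identifier is known.
mnist_path = dataset_node["Mnist"]["path"]

print(batch_sizes, mnist_path)
```

The top-level identifiers are what keep the two files from overwriting each other's keys when they are merged under the same group.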

Useful links