橦言无忌

一个不想改变世界的程序媛

Hydra in Python

前言

Hydra 是一个非常强大且易用的工具,适合用于复杂配置的管理场景。它的模块化、动态覆盖和实验记录功能尤其适合机器学习和深度学习项目。

Hydra,配置管理

一,什么是 Hydra?

Hydra 是一个强大的配置管理框架,主要用于 Python 项目。它允许开发者以模块化的方式管理配置,同时支持多层次的配置继承、覆盖、动态命令行参数修改等功能,适用于机器学习、深度学习等需要大量配置管理的场景。

二,Hydra 的核心特性

1.    配置模块化:通过分离配置文件实现灵活的配置管理。
2.    配置继承:支持继承基础配置并在此基础上进行修改。
3.    动态配置覆盖:可以在运行时通过命令行覆盖默认配置。
4.    实验管理:内置实验记录功能,自动保存每次运行的配置和结果。
5.    支持 YAML 配置格式:更人性化和易读。

三,安装

1
pip install hydra-core

四,常用功能示例

1,基本用法

配置文件:config.yaml

1
2
3
4
5
6
database:
driver: mysql
host: localhost
port: 3306
username: root
password: root

主程序:app.py

1
2
3
4
5
6
7
8
9
10
11
import hydra
from omegaconf import DictConfig

@hydra.main(config_path=".", config_name="config")
def main(cfg: DictConfig):
print(f"Database driver: {cfg.database.driver}")
print(f"Host: {cfg.database.host}")
print(f"Port: {cfg.database.port}")

if __name__ == "__main__":
main()

运行:

1
python app.py

输出:

1
2
3
Database driver: mysql
Host: localhost
Port: 3306

2,动态覆盖配置

通过命令行修改默认配置:

1
python app.py database.driver=postgresql database.port=5432

输出:

1
2
3
Database driver: postgresql
Host: localhost
Port: 5432

3,配置继承

Hydra 支持通过多个 YAML 文件实现配置继承。

基础配置:config.yaml

1
2
3
4
5
defaults:
- db: mysql

database:
timeout: 30

子配置:db/mysql.yaml

1
2
3
driver: mysql
host: localhost
port: 3306

子配置:db/postgresql.yaml

1
2
3
driver: postgresql
host: localhost
port: 5432

主程序:app.py

1
2
3
4
5
6
7
8
9
import hydra
from omegaconf import DictConfig

@hydra.main(config_path=".", config_name="config")
def main(cfg: DictConfig):
print(cfg.database)

if __name__ == "__main__":
main()

运行:

1
python app.py

输出:

1
{'driver': 'mysql', 'host': 'localhost', 'port': 3306, 'timeout': 30}

运行时覆盖默认配置:

1
python app.py db=postgresql

输出:

1
{'driver': 'postgresql', 'host': 'localhost', 'port': 5432, 'timeout': 30}

4,多任务配置

Hydra 支持为不同任务配置不同的参数。

配置文件:config.yaml

1
2
3
4
5
6
7
8
9
10
defaults:
- task: classification

task:
classification:
model: resnet50
dataset: cifar10
regression:
model: linear
dataset: boston

主程序:app.py

1
2
3
4
5
6
7
8
9
10
import hydra
from omegaconf import DictConfig

@hydra.main(config_path=".", config_name="config")
def main(cfg: DictConfig):
print(f"Task: {cfg.task.model}")
print(f"Dataset: {cfg.task.dataset}")

if __name__ == "__main__":
main()

运行不同任务:

1
2
python app.py task=classification
python app.py task=regression

5,实验管理与工作目录

Hydra 默认会为每次运行创建一个独立的工作目录并保存配置和输出。
主程序:app.py

1
2
3
4
5
6
7
8
9
10
import hydra
from omegaconf import DictConfig
import os

@hydra.main(config_path=".", config_name="config")
def main(cfg: DictConfig):
print(f"Working directory: {os.getcwd()}")

if __name__ == "__main__":
main()

运行:

1
python app.py

默认输出:

1
Working directory: /path/to/outputs/2024-12-06/11-42-33

五,多个配置文件夹举例

在 Python 中使用 Hydra 时,可以处理多个配置文件夹,通过将它们组织到不同的目录中并设置 config_path 和 defaults 来指定配置来源。

文件目录

假设我们有以下目录结构:

1
2
3
4
5
6
7
8
9
10
my_app/
├── conf/
│ ├── db/
│ │ ├── mysql.yaml
│ │ ├── postgresql.yaml
│ ├── task/
│ │ ├── classification.yaml
│ │ ├── regression.yaml
│ ├── config.yaml
├── app.py

  • db 文件夹:包含数据库相关的配置。
  • task 文件夹:包含任务相关的配置。
  • config.yaml:主配置文件,用于整合子配置。
  • app.py:主程序文件。

配置文件内容

1,主配置文件:conf/config.yaml

1
2
3
4
5
6
defaults:
- db: mysql # 默认使用 mysql 数据库配置
- task: classification # 默认使用 classification 任务配置

global:
debug: true

2,数据库配置文件:conf/db/mysql.yaml

1
2
3
4
5
driver: mysql
host: localhost
port: 3306
username: root
password: password

3,数据库配置文件:conf/db/postgresql.yaml

1
2
3
4
5
driver: postgresql
host: localhost
port: 5432
username: postgres
password: password

4,任务配置文件:conf/task/classification.yaml

1
2
3
model: resnet50
dataset: cifar10
epochs: 50

5,任务配置文件:conf/task/regression.yaml

1
2
3
model: linear
dataset: boston
epochs: 100

主程序文件:app.py

1
2
3
4
5
6
7
8
9
10
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig):
print("Configuration:")
print(OmegaConf.to_yaml(cfg))

if __name__ == "__main__":
main()

运行和结果

直接运行

1
python app.py

输出:

1
2
3
4
5
6
7
8
9
10
11
12
db:
driver: mysql
host: localhost
port: 3306
username: root
password: password
task:
model: resnet50
dataset: cifar10
epochs: 50
global:
debug: true

覆盖数据库配置

指定 PostgreSQL 数据库

1
python app.py db=postgresql

输出:

1
2
3
4
5
6
7
8
9
10
11
12
db:
driver: postgresql
host: localhost
port: 5432
username: postgres
password: password
task:
model: resnet50
dataset: cifar10
epochs: 50
global:
debug: true

覆盖任务配置

指定回归任务:

1
python app.py task=regression

输出

1
2
3
4
5
6
7
8
9
10
11
12
db:
driver: mysql
host: localhost
port: 3306
username: root
password: password
task:
model: linear
dataset: boston
epochs: 100
global:
debug: true

动态覆盖多个参数

1
python app.py db=postgresql task=regression global.debug=false

输出:
1
2
3
4
5
6
7
8
9
10
11
12
db:
driver: postgresql
host: localhost
port: 5432
username: postgres
password: password
task:
model: linear
dataset: boston
epochs: 100
global:
debug: false

配置多个文件夹的关键点

1.    文件夹结构清晰:
•    根据功能将配置文件组织到不同的子文件夹(例如 db 和 task)。
2.    defaults 声明:
•    在主配置文件中通过 defaults 指定默认使用的子配置文件。
3.    动态覆盖:
•    使用 命令行参数 实现动态覆盖,方便切换配置。
4.    扩展配置路径(高级):

如果需要加载多个根路径的配置,可以通过环境变量或者 hydra.searchpath 添加多个路径。例如:

1
2
3
4
5
6
@hydra.main(config_path=None, config_name="config")
def main(cfg: DictConfig):
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize(config_path="conf")
cfg = hydra.compose(config_name="config", overrides=["+hydra.searchpath=[/custom/path]"])
print(OmegaConf.to_yaml(cfg))

这种方式可以在运行时动态扩展配置源路径,非常适合复杂场景。

// 代码折叠