前言
Hydra 是一个非常强大且易用的工具,适合用于复杂配置的管理场景。它的模块化、动态覆盖和实验记录功能尤其适合机器学习和深度学习项目。
Hydra,配置管理
一,是什么?
pytorch-lightning 是一个基于 PyTorch 的高层次深度学习框架,旨在简化深度学习模型的训练和评估过程。它通过将研究逻辑与工程代码分离,使代码更易读、易维护,并能快速实现各种功能如分布式训练、混合精度训练和模型调试。
二,主要特点
- 清晰的代码结构:将模型、训练步骤、验证逻辑等分离。
- 易用性:提供多种开箱即用的功能,例如日志记录、回调、分布式训练。
- 兼容性:与 PyTorch 原生代码兼容,可以逐步迁移。
- 自动化:减少样板代码,自动处理设备管理(如 GPU/TPU)、混合精度等。
- 扩展性:支持插件式扩展,例如自定义优化器、学习率调度器、回调等。
三,模块安装
使用 pip 安装:1
pip install lightning
四,核心组件
- LightningModule
核心类,用于定义模型、训练、验证和测试逻辑。 - Trainer
控制训练过程的类,支持设备管理、日志记录、回调等。 - DataModule
用于数据加载和预处理的模块(可选)。
五,使用示例:简单的分类任务
以下是使用 pytorch-lightning 训练一个 MNIST 数据集分类器的完整代码:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import lightning as L
# 定义 LightningModule
class MNISTModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 10)
)
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.cross_entropy(logits, y)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.cross_entropy(logits, y)
acc = (logits.argmax(dim=-1) == y).float().mean()
self.log("val_loss", loss, prog_bar=True)
self.log("val_acc", acc, prog_bar=True)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
# 定义 DataModule
class MNISTDataModule(L.LightningDataModule):
def __init__(self, data_dir="./data", batch_size=32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
def prepare_data(self):
datasets.MNIST(self.data_dir, train=True, download=True)
datasets.MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
if stage == "fit" or stage is None:
mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
self.train_set, self.val_set = random_split(mnist_full, [55000, 5000])
if stage == "test" or stage is None:
self.test_set = datasets.MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.train_set, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.val_set, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.test_set, batch_size=self.batch_size)
# 实例化模型和数据模块
model = MNISTModel()
data_module = MNISTDataModule()
# 实例化 Trainer 并开始训练
trainer = L.Trainer(
max_epochs=5,
accelerator="auto", # 自动选择 CPU 或 GPU
devices=1, # 单 GPU 或 CPU 训练
precision=16 # 混合精度训练
)
trainer.fit(model, data_module)
trainer.test(model, data_module)
代码说明
模块化设计:
- 模型逻辑放在 MNISTModel 中。
- 数据加载逻辑放在 MNISTDataModule 中。
Trainer:
- 自动管理训练过程,包括设备选择、混合精度和日志记录。
日志与可视化:
- self.log 用于记录指标。
- 可以轻松集成 TensorBoard 或其他日志工具。
六,常见功能扩展
多 GPU 训练:
设置 devices=4,并使用 accelerator=’gpu’。
回调:
添加自定义回调函数,如早停(EarlyStopping)或模型检查点(ModelCheckpoint)。
1 | from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint |
分布式训练:
支持分布式训练,设置 strategy=’ddp’。
七,适用场景
- 需要快速实验和调参。
- 需要跨设备(CPU/GPU/TPU)部署。
- 项目代码需要结构化管理和模块化设计。
通过 Lightning,深度学习的研发和生产过程会更加高效且易于扩展。