橦言无忌

一个不想改变世界的程序媛

lightning in Python

前言

Hydra 是一个非常强大且易用的工具,适合用于复杂配置的管理场景。它的模块化、动态覆盖和实验记录功能尤其适合机器学习和深度学习项目。

Hydra,配置管理

一,是什么?

pytorch-lightning 是一个基于 PyTorch 的高层次深度学习框架,旨在简化深度学习模型的训练和评估过程。它通过将研究逻辑与工程代码分离,使代码更易读、易维护,并能快速实现各种功能如分布式训练、混合精度训练和模型调试。

二,主要特点

  • 清晰的代码结构:将模型、训练步骤、验证逻辑等分离。
  • 易用性:提供多种开箱即用的功能,例如日志记录、回调、分布式训练。
  • 兼容性:与 PyTorch 原生代码兼容,可以逐步迁移。
  • 自动化:减少样板代码,自动处理设备管理(如 GPU/TPU)、混合精度等。
  • 扩展性:支持插件式扩展,例如自定义优化器、学习率调度器、回调等。

三,模块安装

使用 pip 安装:

1
pip install lightning

四,核心组件

  • LightningModule
    核心类,用于定义模型、训练、验证和测试逻辑。
  • Trainer
    控制训练过程的类,支持设备管理、日志记录、回调等。
  • DataModule
    用于数据加载和预处理的模块(可选)。

五,使用示例:简单的分类任务

以下是使用 pytorch-lightning 训练一个 MNIST 数据集分类器的完整代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import lightning as L

# 定义 LightningModule
class MNISTModel(L.LightningModule):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 10)
)

def forward(self, x):
return self.model(x)

def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.cross_entropy(logits, y)
self.log("train_loss", loss)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.cross_entropy(logits, y)
acc = (logits.argmax(dim=-1) == y).float().mean()
self.log("val_loss", loss, prog_bar=True)
self.log("val_acc", acc, prog_bar=True)

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)

# 定义 DataModule
class MNISTDataModule(L.LightningDataModule):
def __init__(self, data_dir="./data", batch_size=32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])

def prepare_data(self):
datasets.MNIST(self.data_dir, train=True, download=True)
datasets.MNIST(self.data_dir, train=False, download=True)

def setup(self, stage=None):
if stage == "fit" or stage is None:
mnist_full = datasets.MNIST(self.data_dir, train=True, transform=self.transform)
self.train_set, self.val_set = random_split(mnist_full, [55000, 5000])
if stage == "test" or stage is None:
self.test_set = datasets.MNIST(self.data_dir, train=False, transform=self.transform)

def train_dataloader(self):
return DataLoader(self.train_set, batch_size=self.batch_size)

def val_dataloader(self):
return DataLoader(self.val_set, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.test_set, batch_size=self.batch_size)

# 实例化模型和数据模块
model = MNISTModel()
data_module = MNISTDataModule()

# 实例化 Trainer 并开始训练
trainer = L.Trainer(
max_epochs=5,
accelerator="auto", # 自动选择 CPU 或 GPU
devices=1, # 单 GPU 或 CPU 训练
precision=16 # 混合精度训练
)
trainer.fit(model, data_module)
trainer.test(model, data_module)

代码说明

模块化设计:

  • 模型逻辑放在 MNISTModel 中。
  • 数据加载逻辑放在 MNISTDataModule 中。

    Trainer:

  • 自动管理训练过程,包括设备选择、混合精度和日志记录。

    日志与可视化:

  • self.log 用于记录指标。
  • 可以轻松集成 TensorBoard 或其他日志工具。

六,常见功能扩展

多 GPU 训练:

设置 devices=4,并使用 accelerator=’gpu’。

回调:

添加自定义回调函数,如早停(EarlyStopping)或模型检查点(ModelCheckpoint)。

1
2
3
4
5
6
7
8
9
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

early_stop_callback = EarlyStopping(monitor="val_loss", patience=3)
checkpoint_callback = ModelCheckpoint(monitor="val_loss")

trainer = L.Trainer(
callbacks=[early_stop_callback, checkpoint_callback],
max_epochs=10,
)

分布式训练:

支持分布式训练,设置 strategy=’ddp’。

七,适用场景

  • 需要快速实验和调参。
  • 需要跨设备(CPU/GPU/TPU)部署。
  • 项目代码需要结构化管理和模块化设计。

通过 Lightning,深度学习的研发和生产过程会更加高效且易于扩展。

// 代码折叠