手撸AI-4: Accelerate库分布式训练详解

慈云数据 2024-03-13 技术支持 77 0

一. 引言

        Accelerate 是 Hugging Face 公司开发的一个 Python 库,旨在简化并优化在各种环境中进行深度学习训练的过程,包括单机、多 GPU、TPU 和各种分布式训练环境。这个库提供了一种通用的 API,可以方便地将原来只能在单个设备上运行的代码扩展到多设备和分布式环境。在平常我们阅读源码或者编写训练流程的时候acceletate尤为重要.
官方文档和教程icon-default.png?t=N7T8https://huggingface.co/docs/accelerate/basic_tutorials

 二. Accelerate库安装及使用

 2.1 安装 

        这里要注意pytorch的版本兼容问题,不指定版本(accelerate 0.27.2)的话需要torch>=1.10.0,而作者的torch版本是torch 1.9.1和cuda111版本对应,因此选择安装 accelerate 的 0.5.0 版本.但是该版本无法使用accelerate env命令查看环境状态.

pip install accelerate -i https://pypi.tuna.tsinghua.edu.cn/simple

2.2 创建实例

       导入Accelerator主类并在accelerator对象中实例化一个。

from accelerate import Accelerator
accelerator = Accelerator()

        将其添加到训练脚本的开头,因为它将初始化分布式训练所需的所有内容。无需指明您所处的环境类型(具有 GPU 的单台机器、具有多个 GPU 的机器或具有多个 GPU 或 TPU 的多台机器),库会自动检测到这一点。

2.3 删除模型和输入数据.to(device)和.cuda()

Accelerator对象将负责将这些对象放置在适合的设备上。

2.3 使用Accelerator.prepare 方法

创建这些对象后,在开始实际的训练循环之前,将所有与训练相关的 PyTorch 对象(优化器、模型、数据加载器、验证加载器, 学习率调度程序)传递给prepare方法

model, optimizer, train_dataloader, val_loader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader,val_loader, lr_scheduler
)

2.4 将loss.backward()行替换为accelerator.backward(loss)

2.5 每一epoch下模型的评测和保存

if acc > best_acc:
    # 更新 best_acc
    best_acc = acc
    
    if args.save:
        # 如果 save 参数为真,保存模型权重
        if not os.path.exists('save'):
            os.mkdir('save')
        
        # 等待所有进程完成
        accelerator.wait_for_everyone()
        
        # 获取未包装的模型
        unwrapped_model = accelerator.unwrap_model(model)
        
        # 保存权重
        accelerator.save(unwrapped_model.state_dict(), os.path.join(
            'save', f'ep{epoch}_acc{acc * 1000:.0f}.pth'))

三. 分布式脚本使用

3.3 快速配置(默认)

通过指定--config_file 标志,您可以指定配置文件的替代位置

accelerate config

设置结束时,default_config.yaml文件将保存在您的缓存文件夹中以用于 🤗 Accelerate。该缓存文件夹是(按优先级降序排列):

  • 环境变量的内容以加速HF_HOME为后缀。
  • 如果不存在,则环境变量的内容XDG_CACHE_HOME以 Huggingface/accelerate为后缀。
  • 如果也不存在,则文件夹~/.cache/huggingface/accelerate

    3.4 测试配置

    默认配置测试:

    accelerate test

    也可通过指定--config_file 标志指定配置测试.

    3.5 运行主函数

    accelerate launch --config_file config.yaml main.py --args
微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon