【Python】科研代码学习:十四 wandb (可视化AI工具)

慈云数据 2024-04-01 技术支持 51 0

【Python】科研代码学习:十四 wandb[可视化AI工具]

  • wandb 介绍
  • 注册账号
  • 使用 `HF Trainer` + `wandb` 训练
  • 低级 API

    wandb 介绍

    • 【wandb官网

      wandb 是 Weights & Biases 的缩写(w and b)

    • 核心作用
      • 可视化重要参数
      • 云端存储
      • 提供各种工具
      • 可以和其他工具配合使用,比如下面的 pytorch, HF transformers, tensorflow, keras 等等

        在这里插入图片描述

      • 可以在里面使用 matplotlib
      • 貌似是 tensorboard 的上位替代

        注册账号

        • 首先我们需要去官网注册账号,貌似不能使用***

          注册号后,按照教程创建一个团队,然后来到这个界面

          可以按照这个 Quickstart 的样例走一下。选择 Track Runs,接下来可以选择使用哪个工具训练的模型

          然后需要 pip install wandb 导包,以及 wandb login 登录

          在这里插入图片描述

          使用 HF Trainer + wandb 训练

          • 我们调用官方给的样例

            我们发现其实新添了这几个内容:

            WANDB_PROJECT 环境变量:项目名

            WANDB_LOG_MODEL 环境变量:是否保存中继到wandb

            WANDB_WATCH环境变量

          • 在 TrainingArguments 中,设置了 report_to="wandb"

            最后调用 wandb.finish() ,整体变化不大

            # This script needs these libraries to be installed: 
            #   numpy, transformers, datasets
            import wandb 
            import os
            import numpy as np
            from datasets import load_dataset
            from transformers import TrainingArguments, Trainer
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
            # 设置GPU编号
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
            os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
            def tokenize_function(examples):
                return tokenizer(examples["text"], padding="max_length", truncation=True)
            def compute_metrics(eval_pred):
                logits, labels = eval_pred
                predictions = np.argmax(logits, axis=-1)
                return {"accuracy": np.mean(predictions == labels)}
            print("Loading Dataset")
            # download prepare the data
            dataset = load_dataset("yelp_review_full")
            print("Loading Tokenizer")
            tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
            small_train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
            small_eval_dataset = dataset["test"].shuffle(seed=42).select(range(300))
            small_train_dataset = small_train_dataset.map(tokenize_function, batched=True)
            small_eval_dataset = small_train_dataset.map(tokenize_function, batched=True)
            print("Loading Model")
            # download the model
            model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=5)
            # set the wandb project where this run will be logged
            os.environ["WANDB_PROJECT"]="my-awesome-project"
            # save your trained model checkpoint to wandb
            os.environ["WANDB_LOG_MODEL"]="true"
            # turn off watch to log faster
            os.environ["WANDB_WATCH"]="false"
            # pass "wandb" to the 'report_to' parameter to turn on wandb logging
            training_args = TrainingArguments(
                output_dir='models',
                report_to="wandb",
                logging_steps=5, 
                per_device_train_batch_size=32,
                per_device_eval_batch_size=32,
                evaluation_strategy="steps",
                eval_steps=20,
                max_steps = 100,
                save_steps = 100
            )
            print("Loading Trainer")
            # define the trainer and start training
            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=small_train_dataset,
                eval_dataset=small_eval_dataset,
                compute_metrics=compute_metrics,
            )
            print("Training")
            trainer.train()
            # [optional] finish the wandb run, necessary in notebooks
            wandb.finish()
            
            • 在 wandb 网站中

              我们可以打开该 project。每一次运行相当于一次 run,我这里跑了三次所以就有三条线。

              这里主要是看 eval 验证集和 train 训练集的一些参数。

              在这里插入图片描述

              在这里插入图片描述

            • 我们可以删掉不关心的面板,或者增添一个想看的面板

              但如果两个参数的值域变化比较大的话,在一个图里面比较难看清,所以比较相关的参数才建议放在一个图里。

              在这里插入图片描述

              低级 API

              • 这上面是封装比较高级的 API,一般我们也都配合 transformers 库去用

                如果想用比较原生的 API,一般用法如下:

                首先调用 wandb.init() 方法

                然后使用 wandb.log(dict) 输出你要可视化的参数即可。

                # train.py
                import wandb
                import random  # for demo script
                wandb.login()
                epochs = 10
                lr = 0.01
                run = wandb.init(
                    # Set the project where this run will be logged
                    project="my-awesome-project",
                    # Track hyperparameters and run metadata
                    config={
                        "learning_rate": lr,
                        "epochs": epochs,
                    },
                )
                offset = random.random() / 5
                print(f"lr: {lr}")
                # simulating a training run
                for epoch in range(2, epochs):
                    acc = 1 - 2**-epoch - random.random() / epoch - offset
                    loss = 2**-epoch + random.random() / epoch + offset
                    print(f"epoch={epoch}, accuracy={acc}, loss={loss}")
                    wandb.log({"accuracy": acc, "loss": loss})
                # run.log_code()
                
微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon