那曲檬骨新材料有限公司

電子發燒友App

硬聲App

0
  • 聊天消息
  • 系統消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發帖/加入社區
會員中心
創作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示
電子發燒友網>電子資料下載>電子資料>PyTorch教程4.3之基本分類模型

PyTorch教程4.3之基本分類模型

2023-06-05 | pdf | 0.13 MB | 次下載 | 免費

資料介紹

您可能已經注意到,在回歸的情況下,從頭開始的實現和使用框架功能的簡潔實現非常相似。分類也是如此。由于本書中的許多模型都處理分類,因此值得添加專門支持此設置的功能。本節為分類模型提供了一個基類,以簡化以后的代碼。

import torch
from d2l import torch as d2l
from mxnet import autograd, gluon, np, npx
from d2l import mxnet as d2l

npx.set_np()
from functools import partial
import jax
import optax
from jax import numpy as jnp
from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf
from d2l import tensorflow as d2l

4.3.1. Classifier_

我們在下面定義Classifier類。在中,validation_step我們報告了驗證批次的損失值和分類準確度。我們為每個批次繪制一個更新num_val_batches這有利于在整個驗證數據上生成平均損失和準確性。如果最后一批包含的示例較少,則這些平均數并不完全正確,但我們忽略了這一微小差異以保持代碼簡單。

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
    self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
    self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.

We also redefine the training_step method for JAX since all models that will subclass Classifier later will have a loss that returns auxiliary data. This auxiliary data can be used for models with batch normalization (to be explained in Section 8.5), while in all other cases we will make the loss also return a placeholder (empty dictionary) to represent the auxiliary data.

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def training_step(self, params, batch, state):
    # Here value is a tuple since models with BatchNorm layers require
    # the loss to return auxiliary data
    value, grads = jax.value_and_grad(
      self.loss, has_aux=True)(params, batch[:-1], batch[-1], state)
    l, _ = value
    self.plot("loss", l, train=True)
    return value, grads

  def validation_step(self, params, batch, state):
    # Discard the second returned value. It is used for training models
    # with BatchNorm layers since loss also returns auxiliary data
    l, _ = self.loss(params, batch[:-1], batch[-1], state)
    self.plot('loss', l, train=False)
    self.plot('acc', self.accuracy(params, batch[:-1], batch[-1], state),
         train=False)

We define the Classifier class below. In the validation_step we report both the loss value and the classification accuracy on a validation batch. We draw an update for every num_val_batches batches. This has the benefit of generating the averaged loss and accuracy on the whole validation data. These average numbers are not exactly correct if the last batch contains fewer examples, but we ignore this minor difference to keep the code simple.

class Classifier(d2l.Module): #@save
  """The base class of classification models."""
  def validation_step(self, batch):
    Y_hat = self(*batch[:-1])
    self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
    self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)

默認情況下,我們使用隨機梯度下降優化器,在小批量上運行,就像我們在線性回歸的上下文中所做的那樣。

@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  return torch.optim.SGD(self.parameters(), lr=self.lr)
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  params = self.parameters()
  if isinstance(params, list):
    return d2l.SGD(params, self.lr)
  return gluon.Trainer(params, 'sgd', {'learning_rate': self.lr})
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  return optax.sgd(self.lr)
@d2l.add_to_class(d2l.Module) #@save
def configure_optimizers(self):
  return tf.keras.optimizers.SGD(self.lr)

4.3.2. 準確性

給定預測概率分布y_hat,每當我們必須輸出硬預測時,我們通常會選擇預測概率最高的類別。事實上,許多應用程序需要我們做出選擇。例如,Gmail 必須將電子郵件分類為“主要”、“社交”、“更新”、“論壇”或“垃圾郵件”。它可能會在內部估計概率,但最終它必須在類別中選擇一個。

當預測與標簽 class 一致時y,它們是正確的。分類準確度是所有正確預測的分數。盡管直接優化精度可能很困難(不可微分),但它通常是我們最關心的性能指標。它通常是基準測試中的相關數量因此,我們幾乎總是在訓練分類器時報告它。

準確度計算如下。首先,如果y_hat是一個矩陣,我們假設第二個維度存儲每個類別的預測分數。我們使用argmax每行中最大條目的索引來獲取預測類。然后我們將預測的類別與真實的元素進行比較y由于相等運算符== 對數據類型敏感,因此我們轉換 的y_hat數據類型以匹配 的數據類型y結果是一個包含條目 0(假)和 1(真)的張量。求和得出正確預測的數量。

@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
  """Compute the number of correct predictions."""
  Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
  preds = Y_hat.argmax(axis=1).type(Y.dtype)
  compare = (preds == Y.reshape(-1)).type(torch.float32)
  return compare.mean() if averaged else compare
@d2l.add_to_class(Classifier) #@save
def accuracy(self, Y_hat, Y, averaged=True):
  """Compute the number of correct predictions."""
  Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
  preds = Y_hat.argmax(axis=1).astype(Y.dtype)
  compare = (preds == Y.reshape(-1)).astype(np.float32)
  return compare.mean() if averaged else compare

@d2l.add_to_class(d2l.Module) #@save
def get_scratch_params(self):
  params = []
  for attr in dir(self):
    a = getattr(self, attr)
    if isinstance(a, np.ndarray):
      params.append(a)
    if isinstance(a, d2l.Module):
      params.extend(a.get_scratch_params())
  return params

@d2l.add_to_class(d2l.Module) #@save
def parameters(self):
  params = self.collect_params()
  return params if isinstance(params, gluon.parameter.ParameterDict) and

下載該資料的人也在下載 下載該資料的人還在閱讀
更多 >

評論

查看更多

下載排行

本周

  1. 1A7159和A7139射頻芯片的資料免費下載
  2. 0.20 MB   |  55次下載  |  5 積分
  3. 2PIC12F629/675 數據手冊免費下載
  4. 2.38 MB   |  36次下載  |  5 積分
  5. 3PIC16F716 數據手冊免費下載
  6. 2.35 MB   |  18次下載  |  5 積分
  7. 4dsPIC33EDV64MC205電機控制開發板用戶指南
  8. 5.78MB   |  8次下載  |  免費
  9. 5STC15系列常用寄存器匯總免費下載
  10. 1.60 MB   |  7次下載  |  5 積分
  11. 6模擬電路仿真實現
  12. 2.94MB   |  4次下載  |  免費
  13. 7PCB圖繪制實例操作
  14. 2.92MB   |  2次下載  |  免費
  15. 8零死角玩轉STM32F103—指南者
  16. 26.78 MB   |  1次下載  |  1 積分

本月

  1. 1ADI高性能電源管理解決方案
  2. 2.43 MB   |  452次下載  |  免費
  3. 2免費開源CC3D飛控資料(電路圖&PCB源文件、BOM、
  4. 5.67 MB   |  141次下載  |  1 積分
  5. 3基于STM32單片機智能手環心率計步器體溫顯示設計
  6. 0.10 MB   |  137次下載  |  免費
  7. 4A7159和A7139射頻芯片的資料免費下載
  8. 0.20 MB   |  55次下載  |  5 積分
  9. 5PIC12F629/675 數據手冊免費下載
  10. 2.38 MB   |  36次下載  |  5 積分
  11. 6如何正確測試電源的紋波
  12. 0.36 MB   |  19次下載  |  免費
  13. 7PIC16F716 數據手冊免費下載
  14. 2.35 MB   |  18次下載  |  5 積分
  15. 8Q/SQR E8-4-2024乘用車電子電器零部件及子系統EMC試驗方法及要求
  16. 1.97 MB   |  8次下載  |  10 積分

總榜

  1. 1matlab軟件下載入口
  2. 未知  |  935121次下載  |  10 積分
  3. 2開源硬件-PMP21529.1-4 開關降壓/升壓雙向直流/直流轉換器 PCB layout 設計
  4. 1.48MB  |  420062次下載  |  10 積分
  5. 3Altium DXP2002下載入口
  6. 未知  |  233088次下載  |  10 積分
  7. 4電路仿真軟件multisim 10.0免費下載
  8. 340992  |  191367次下載  |  10 積分
  9. 5十天學會AVR單片機與C語言視頻教程 下載
  10. 158M  |  183335次下載  |  10 積分
  11. 6labview8.5下載
  12. 未知  |  81581次下載  |  10 積分
  13. 7Keil工具MDK-Arm免費下載
  14. 0.02 MB  |  73810次下載  |  10 積分
  15. 8LabVIEW 8.6下載
  16. 未知  |  65988次下載  |  10 積分
大发888破解方法| 百家乐官网娱乐城地址| 太阳百家乐网| 波音百家乐网上娱乐| 速博百家乐官网的玩法技巧和规则| 百家乐官网赌经| 百家乐官网园胎教网| 百家乐游戏| 六合彩开奖日期| e娱乐城棋牌| 澳门顶级赌场网址| 大发888游戏下载平台| 真人游戏平台| 大发888娱乐城贴吧| 新全讯网网址xb112| 百家乐手论坛48491| 巴厘岛百家乐的玩法技巧和规则 | 衡阳县| 五寨县| 哪里有百家乐官网代理| 百家乐官网的关键技巧| 百家乐官网哪家信誉好| 百家乐官网游戏群号| 马牌百家乐官网现金网| 博九网百家乐官网现金网| 百家乐官网胜率被控制| 视频百家乐官网平台出租| 易胜博百家乐官网娱乐城| 百家乐官网赌博千术| 百家乐官网路单| 立博百家乐官网的玩法技巧和规则 | 百家乐官网电脑游戏机投注法实例| 百家乐官网赌博大全| 巴特百家乐官网的玩法技巧和规则| 都坊百家乐官网的玩法技巧和规则 | 百家乐官网推荐怎么看| 百家乐官网分析仪博彩正网| 百家乐官网赌博技巧大全| 大哥大百家乐官网的玩法技巧和规则| 百家乐玩法与规则| 雅加达百家乐的玩法技巧和规则|