pytorch比较操作

慈云数据 2024-05-29 技术支持 34 0

文章目录

  • 常用的比较操作
    • 1.torch.allclose()
    • 2.torch.argsort()
    • 3.torch.eq()
    • 4.torch.equal()
    • 5.torch.greater_equal()
    • 6.torch.gt()
    • 7.torch.isclose()
    • 8.torch.isfinite()
    • 9.torch.isif()
    • 10.torch.isposinf()
    • 11.torch.isneginf()
    • 12.torch.isnan()
    • 13.torch.kthvalue()
    • 14.torch.less_equal()
    • 15.torch.maximum()
    • 16.torch.fmax()
    • 17.torch.ne()
    • 18.torch.sort()
    • 19.torch.topk()

      常用的比较操作

      在这里插入图片描述


      1.torch.allclose()

        torch.allclose() 是 PyTorch 中用于比较两个张量是否在给定的容差范围内近似相等的函数。它可以用于比较浮点数张量之间的相等性。

      torch.allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)
      """
      input:第一个输入张量。
      other:第二个输入张量。
      rtol:相对容差(relative tolerance),默认为 1e-05。
      atol:绝对容差(absolute tolerance),默认为 1e-08。
      equal_nan:一个布尔值,指示是否将 NaN 视为相等,默认为 False。
      """
      
      import torch
      # 比较两个张量是否近似相等
      x = torch.tensor([1.0, 2.0, 3.0])
      y = torch.tensor([1.0001, 2.0002, 3.0003])
      is_close = torch.allclose(x, y, rtol=1e-03, atol=1e-05)
      print(is_close)# True
      

      2.torch.argsort()

      torch.argsort() 是 PyTorch 中用于对张量进行排序返回排序后的索引的函数。它返回一个新的张量,其中每个元素表示原始张量中对应位置的元素在排序后的顺序中的索引值。

      torch.argsort(input, dim=-1, descending=False, *, out=None)
      """
      input:输入张量。
      dim:指定排序的维度,默认为 -1,表示最后一个维度。
      descending:一个布尔值,指示是否按降序排序,默认为 False。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 对张量进行排序并返回索引
      x = torch.tensor([3, 1, 4, 2])
      sorted_indices = torch.argsort(x)
      print(sorted_indices)
      # tensor([1, 3, 0, 2])
      

      3.torch.eq()

        torch.eq() 是 PyTorch 中用于执行元素级别相等性比较的函数。它比较两个张量的对应元素,并返回一个新的布尔张量,其中元素为 True 表示对应位置的元素相等,元素为 False 表示对应位置的元素不相等。

      torch.eq(input, other, out=None)
      """
      input:第一个输入张量。
      other:第二个输入张量。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 执行元素级别的相等性比较
      x = torch.tensor([1, 2, 3])
      y = torch.tensor([1, 2, 4])
      result = torch.eq(x, y)
      print(result)# tensor([ True,  True, False])
      

      4.torch.equal()

      torch.equal() 是 PyTorch 中用于检查两个张量是否在元素级别上完全相等的函数。它返回一个布尔值,指示两个张量是否具有相同的形状和相同的元素值。

      torch.equal(input, other)
      """
      input:第一个输入张量。
      other:第二个输入张量。
      """
      
      import torch
      # 检查两个张量是否完全相等
      x = torch.tensor([1, 2, 3])
      y = torch.tensor([1, 2, 3])
      is_equal = torch.equal(x, y)
      print(is_equal)# True
      

      5.torch.greater_equal()

      torch.greater_equal() 是 PyTorch 中用于执行元素级别的大于等于比较的函数。它比较两个张量的对应元素,并返回一个新的布尔张量,其中元素为 True 表示对应位置的元素大于或等于,元素为 False 表示对应位置的元素小于。

      torch.greater_equal(input, other, out=None)
      """
      input:第一个输入张量。
      other:第二个输入张量。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 执行元素级别的大于等于比较
      x = torch.tensor([1, 2, 3])
      y = torch.tensor([2, 2, 2])
      result = torch.greater_equal(x, y)
      print(result)
      
      tensor([False,  True,  True])
      

      6.torch.gt()

      torch.gt() 是 PyTorch 中用于执行元素级别的大于比较的函数。它比较两个张量的对应元素,并返回一个新的布尔张量,其中元素为 True 表示对应位置的元素大于,元素为 False 表示对应位置的元素小于或等于。

      torch.gt(input, other, out=None)
      """
      input:第一个输入张量。
      other:第二个输入张量。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 执行元素级别的大于比较
      x = torch.tensor([1, 2, 3])
      y = torch.tensor([2, 2, 2])
      result = torch.gt(x, y)
      print(result)#tensor([False, False,  True])
      

      7.torch.isclose()

      torch.isclose() 是 PyTorch 中用于比较两个张量是否在给定的容差范围内近似相等的函数。它可以用于比较浮点数张量之间的相等性。

      torch.isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)
      """
      input:第一个输入张量。
      other:第二个输入张量。
      rtol:相对容差(relative tolerance),默认为 1e-05。
      atol:绝对容差(absolute tolerance),默认为 1e-08。
      equal_nan:一个布尔值,指示是否将 NaN 视为相等,默认为 False。
      """
      
      import torch
      # 比较两个张量是否近似相等
      x = torch.tensor([1.0, 2.0, 3.0])
      y = torch.tensor([1.0001, 2.0002, 3.0003])
      is_close = torch.isclose(x, y, rtol=1e-03, atol=1e-05)
      print(is_close)
      
      tensor([True, True, True])
      

      8.torch.isfinite()

      torch.isfinite() 是 PyTorch 中用于检查张量中的元素是否为有限数(finite number)的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为有限数。

      torch.isfinite(input, out=None)
      """
      input:输入张量。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 检查张量中的元素是否为有限数
      x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
      is_finite = torch.isfinite(x)
      print(is_finite)# tensor([ True, False, False, False])
      

      9.torch.isif()

      torch.isinf() 是 PyTorch 中用于检查张量中的元素是否为无穷大的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为无穷大。

      torch.isinf(input, out=None)
      """
      input:输入张量。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 检查张量中的元素是否为无穷大
      x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
      is_inf = torch.isinf(x)
      print(is_inf)
      
      tensor([False,  True,  True, False])
      

      10.torch.isposinf()

      torch.isposinf() 是 PyTorch 中用于检查张量中的元素是否为正无穷大的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为正无穷大。

      torch.isposinf(input, out=None)
      """
      input:输入张量。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 检查张量中的元素是否为正无穷大
      x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
      is_posinf = torch.isposinf(x)
      print(is_posinf)# tensor([False,  True, False, False])
      

      11.torch.isneginf()

      torch.isneginf() 是 PyTorch 中用于检查张量中的元素是否为负无穷大的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为负无穷大。

      torch.isneginf(input, out=None)
      """
      input:输入张量。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 检查张量中的元素是否为负无穷大
      x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
      is_neginf = torch.isneginf(x)
      print(is_neginf)# tensor([False, False,  True, False])
      

      12.torch.isnan()

      torch.isnan() 是 PyTorch 中用于检查张量中的元素是否为 NaN(Not a Number)的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为 NaN。

      torch.isnan(input, out=None)
      """
      input:输入张量。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 检查张量中的元素是否为 NaN
      x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
      is_nan = torch.isnan(x)
      print(is_nan)# tensor([False, False, False,  True])
      

      13.torch.kthvalue()

      torch.kthvalue() 函数用于找出张量中的第 k 小值,而 torch.topk() 函数用于找出张量中的前 k 个最大值(或最小值)及其对应的索引。

      torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)
      """
      input:输入张量。
      k:要找到的最大值(或最小值)的数量。
      dim:可选参数,指定在哪个维度上进行查找。如果未指定,则默认在最后一个维度上查找。
      largest:可选参数,指定是找到最大值还是最小值。默认为 True,表示找到最大值。
      sorted:可选参数,指定结果张量是否按降序排列。默认为 True。
      out:可选参数,用于指定输出张量的位置。
      """
      
      import torch
      # 找出张量中的前 3 个最大值及其索引
      x = torch.tensor([1, 3, 2, 4, 6, 5])
      values, indices = torch.topk(x, k=3)
      print(values)#tensor([6, 5, 4])
      print(indices)#tensor([4, 5, 3])
      

      14.torch.less_equal()

      torch.less_equal() 是 PyTorch 中用于执行逐元素的小于等于(

微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon