Computer >> Máy Tính >  >> Lập trình >> Python

Làm cách nào để lấy kiểu dữ liệu của tensor trong PyTorch?

Một tensor PyTorch là đồng nhất, tức là, tất cả các phần tử của tensor có cùng kiểu dữ liệu. Chúng tôi có thể truy cập loại dữ liệu của tensor bằng cách sử dụng ".dtype" thuộc tính của tensor. Nó trả về kiểu dữ liệu của tensor.

Các bước

  • Nhập thư viện được yêu cầu. Trong tất cả các ví dụ Python sau, thư viện Python bắt buộc là torch . Đảm bảo rằng bạn đã cài đặt nó.

  • Tạo một tensor và in nó.

  • Tính toán T.dtype . Ở đây T là tensor mà chúng ta muốn lấy kiểu dữ liệu.

  • In kiểu dữ liệu của tensor.

Ví dụ 1

Chương trình Python sau đây cho biết cách lấy kiểu dữ liệu của một tensor.

# Import the library
import torch

# Create a tensor of random numbers of size 3x4
T = torch.randn(3,4)
print("Original Tensor T:\n", T)

# Get the data type of above tensor
data_type = T.dtype

# Print the data type of the tensor
print("Data type of tensor T:\n", data_type)

Đầu ra

Original Tensor T:
tensor([[ 2.1768, -0.1328, 0.8155, -0.7967],
         [ 0.1194, 1.0465, 0.0779, 0.9103],
         [-0.1809, 1.8085, 0.8393, -0.2463]])
Data type of tensor T:
torch.float32

Ví dụ 2

# Python program to get data type of a tensor
# Import the library
import torch

# Create a tensor of random numbers of size 3x4
T = torch.Tensor([1,2,3,4])
print("Original Tensor T:\n", T)

# Get the data type of above tensor
data_type = T.dtype

# Print the data type of the tensor
print("Data type of tensor T:\n", data_type)

Đầu ra

Original Tensor T:
   tensor([1., 2., 3., 4.])
Data type of tensor T:
   torch.float32