Computer >> Máy Tính >  >> Lập trình >> Python

Làm cách nào để truy cập siêu dữ liệu của tensor trong PyTorch?

Chúng tôi truy cập kích thước (hoặc hình dạng) của tensor và số phần tử trong tensor dưới dạng siêu dữ liệu của tensor. Để truy cập kích thước của tensor, chúng tôi sử dụng .size () và hình dạng của tensor được truy cập bằng cách sử dụng .shape .

Cả .size () .shape tạo ra cùng một kết quả. Chúng tôi sử dụng torch.numel () hàm để tìm tổng số phần tử trong tensor.

Các bước

  • Nhập thư viện được yêu cầu. Ở đây, thư viện bắt buộc là torch . Đảm bảo rằng bạn đã cài đặt torch .

  • Xác định tensor PyTorch.

  • Tìm siêu dữ liệu của tensor. Sử dụng .size () .shape để truy cập kích thước và hình dạng của tensor. Sử dụng torch.numel () để truy cập số phần tử trong tensor.

  • In tensor và siêu dữ liệu để hiểu rõ hơn.

Ví dụ 1

# Python Program to access meta-data of a Tensor
# import necessary libraries
import torch

# Create a tensor of size 4x3
T = torch.Tensor([[1,2,3],[2,1,3],[2,3,5],[5,6,4]])
print("T:\n", T)

# Find the meta-data of tensor
# Find the size of the above tensor "T"
size_T = T.size()
print("size of tensor T:\n", size_T)

# Other method to get size using .shape
print("Shape of tensor:\n", T.shape)

# Find the number of elements in the tensor "T"
num_T = torch.numel(T)
print("Number of elements in tensor T:\n", num_T)

Đầu ra

Khi bạn chạy mã Python 3 ở trên, nó sẽ tạo ra kết quả sau.

T:
tensor([[1., 2., 3.],
         [2., 1., 3.],
         [2., 3., 5.],
         [5., 6., 4.]])
size of tensor T:
torch.Size([4, 3])
Shape of tensor:
torch.Size([4, 3])
Number of elements in tensor T:
12

Ví dụ 2

# Python Program to access meta-data of a Tensor
# import the libraries
import torch

# Create a tensor of random numbers
T = torch.randn(4,3,2)
print("T:\n", T)

# Find the meta-data of tensor
# Find the size of the above tensor "T"
size_T = T.size()
print("size of tensor T:\n", size_T)

# Other method to get size using .shape
print("Shape of tensor:\n", T.shape)

# Find the number of elements in the tensor "T"
num_T = torch.numel(T)
print("Number of elements in tensor T:\n", num_T)

Đầu ra

Khi bạn chạy mã Python 3 ở trên, nó sẽ tạo ra kết quả sau.

T:
tensor([[[-1.1806, 0.5569],
         [ 2.2237, 0.9709],
         [ 0.4775, -0.2491]],
         [[-0.9703, 1.9916],
         [ 0.1998, -0.6501],
         [-0.7489, -1.3013]],
         [[ 1.3191, 2.0049],
         [-0.1195, 0.1860],
         [-0.6061, -1.2451]],
         [[-0.6044, 0.6153],
         [-2.2473, -0.1531],
         [ 0.5341, 1.3697]]])
size of tensor T:
torch.Size([4, 3, 2])
Shape of tensor:
torch.Size([4, 3, 2])
Number of elements in tensor T:
24