Computer >> Máy Tính >  >> Lập trình >> Python

Làm thế nào để tìm phần tử thứ k và k đầu của một tensor trong PyTorch?

PyTorch cung cấp phương thức torch.kthvalue () để tìm phần tử thứ k của một tensor. Nó trả về giá trị của phần tử thứ k của tensor được sắp xếp theo thứ tự tăng dần và chỉ số của phần tử trong tensor ban đầu.

torch.topk () phương pháp được sử dụng để tìm phần tử "k" hàng đầu. Nó trả về phần tử "k" hàng đầu hoặc "k" lớn nhất trong tensor.

Các bước

  • Nhập thư viện được yêu cầu. Trong tất cả các ví dụ Python sau, thư viện Python bắt buộc là torch . Đảm bảo rằng bạn đã cài đặt nó.

  • Tạo một tensor PyTorch và in nó.

  • Tính toán torch.kthvalue (đầu vào, k) . Nó trả về hai tenxơ. Gán hai tenxơ này cho hai biến mới "value" "chỉ mục" . Ở đây, đầu vào là một tensor và k là một số nguyên.

  • Tính toán torch.topk (đầu vào, k) . Nó trả về hai tenxơ. Tensor đầu tiên có giá trị của các phần tử "k" hàng đầu và tensor thứ hai có chỉ số của các phần tử này trong tensor ban đầu. Gán hai tenxơ này cho các biến "giá trị" mới và "chỉ số" .

  • In giá trị và chỉ số của phần tử thứ k của tensor cũng như các giá trị và chỉ số của phần tử "k" trên cùng của tensor.

Ví dụ 1

Chương trình python này cho biết cách tìm phần tử thứ k của một tensor.

# Python program to find k-th element of a tensor
# import necessary library
import torch

# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# Find the 3rd element in sorted tensor. First it sorts the
# tensor in ascending order then returns the kth element value
# from sorted tensor and the index of element in original tensor
value, index = torch.kthvalue(T, 3)

# print 3rd element with value and index
print("3rd element value:", value)
print("3rd element index:", index)

Đầu ra

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
3rd element value: tensor(2.3340)
3rd element index: tensor(0)

Ví dụ 2

Chương trình Python sau đây cho biết cách tìm phần tử "k" hàng đầu hoặc "k" lớn nhất của một tensor.

# Python program to find to top k elements of a tensor
# import necessary library
import torch

# Create a 1D tensor
T = torch.Tensor([2.334,4.433,-4.33,-0.433,5, 4.443])
print("Original Tensor:\n", T)

# Find the top k=2 or 2 largest elements of the tensor
# returns the 2 largest values and their indices in original
# tensor
values, indices = torch.topk(T, 2)

# print top 2 elements with value and index
print("Top 2 element values:", values)
print("Top 2 element indices:", indices)

Đầu ra

Original Tensor:
   tensor([ 2.3340, 4.4330, -4.3300, -0.4330, 5.0000, 4.4430])
Top 2 element values: tensor([5.0000, 4.4430])
Top 2 element indices: tensor([4, 5])