pytorch教程tensor3
作者:
晓博
,
2022-01-10 10:24:38
,
所有人可见
,
阅读 168
import torch
batch_size = 10
features = 25
x = torch.rand((batch_size, features))
print(x[0].shape)
print(x[:, 0].shape)
print(x[2, 0:10])
x[0, 0] = 100
x = torch.arange(10)
indices = [2, 5, 8]
print(x[indices])
x = torch.rand((3, 5))
rows = torch.tensor([1, 0])
cols = torch.tensor([4, 0])
print(x[rows, cols].shape)
x = torch.arange(10)
print(x[(x < 2 ) | ( x > 8 )] )
print(x[x.remainder(2) == 0]) # 能被2整除
print(torch.where(x > 5, x, x*2)) #当x > 5 输出x 否则输出x的平方
print(torch.tensor([0, 0, 1, 2, 2, 3, 4]).unique())#去重,输出0,1,2,3,4
print(x.ndimension())#输出数组的维度
print(x.numel())#输出数组的数量