pytorch教程tensor4
作者:
晓博
,
2022-01-10 10:25:42
,
所有人可见
,
阅读 188
import torch
x = torch.arange(9)
x_3x3 = x.view(3, 3) #要求内存是连续的
x_3x3 = x.reshape(3, 3) #内存可以不连续
print(x_3x3.shape)
print(x_3x3)
y = x_3x3.t()
print(y)
print(y.contiguous().view(9))
x1 = torch.rand((2, 5))
x2 = torch.rand((2, 5))
print(torch.cat((x1, x2), dim = 0).shape)#行拼接
print(torch.cat((x1, x2), dim = 1).shape)#列拼接
z = x1.view(-1) #行连续
print(z.shape)
batch = 64
x = torch.rand((batch, 2, 5))
z = x.view(batch, -1)
print(z.shape)
z = x.permute(0, 2, 1) #交换维度,按索引分配
print(z.shape)
x = torch.arange(10)
print(x.unsqueeze(0).shape)#(1, 10)
print(x.unsqueeze(1).shape)#(10, 1)
x=torch.arange(10).unsqueeze(0).unsqueeze(1)#在索引处增加维度(1, 1, 10)
print(x.shape)
z = x.squeeze(1)#减少一个维度
print(z.shape)