pytorch使用

发布时间 2023-09-19 00:17:16作者: 失控D大白兔
import torch
import numpy as np

1. 张量(tensor)

1.1 初始化

data = [[1, 2],[3, 4]]
x_data = torch.tensor(data) #从列表初始化
np_array = np.array(data)
x_data = torch.from_numpy(np_array) #从numpy初始化
x_ones = torch.ones_like(x_data)    #保持相同规模,全部置1
x_rand = torch.rand_like(x_data, dtype=torch.float) #保持相同规模,生成随机数

根据维度进行初始化

shape = (2,3,)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)

1.2 张量运算

像numpy一样进行索引和切片

tensor[0] #第一行
tensor[:,0] #第一列
tensor[...,-1] #最后一列
t1 = torch.cat([tensor, tensor, tensor], dim=1) #进行拼接

算数运算