torch

torch.nn

torch.nn.MSELoss

均方误差损失

参数解释
reduction汇总方法
none:返回每个元素的损失值
mean:返回所有元素的损失的平均值(默认)
sum:返回所有元素的损失值的和

torch.nn.CrossEntropyLoss

交叉熵损失

参数解释
reduction汇总方法
none:返回每个元素的损失值
mean:返回所有元素的损失的平均值(默认)
sum:返回所有元素的损失值的和

torch.optim

torch.optim.SGD()

随机梯度下降优化器

输入

参数解释
params要优化的参数
lr学习率
momentum用于加速梯度下降,并避免在鞍点附近震荡。默认值为 0
dampening动量的衰减因子。默认值为 0
weight_decay权重衰减(L2 正则化),用于防止过拟合。默认值为 0
nesterov是否使用 Nesterov 动量,True 表示使用。默认值为 False

返回

优化器对象

torch.utils

torch.utils.data

torch.utils.data.Dataloader()

批量加载数据,生成一个迭代器,随机抽取批量大小的数据

输入

参数解释
dataset数据集
batch_size=32批次大小
shuffle=True是否打乱
num_workers=4使用的进程数

返回

dataloader 迭代器

math

math.gamma

numpy

np.power

计算指数,可放入列表

np.arange

生成一维列表,含头不含尾
arange(start,stop,step,dtype)

参数解释
start开始
stop结束
step间隔
dtype数据类型