torch
torch.nn
torch.nn.MSELoss
均方误差损失
参数 | 解释 |
---|---|
reduction | 汇总方法 none:返回每个元素的损失值 mean:返回所有元素的损失的平均值(默认) sum:返回所有元素的损失值的和 |
torch.nn.CrossEntropyLoss
交叉熵损失
参数 | 解释 |
---|---|
reduction | 汇总方法 none:返回每个元素的损失值 mean:返回所有元素的损失的平均值(默认) sum:返回所有元素的损失值的和 |
torch.optim
torch.optim.SGD()
随机梯度下降优化器
输入
参数 | 解释 |
---|---|
params | 要优化的参数 |
lr | 学习率 |
momentum | 用于加速梯度下降,并避免在鞍点附近震荡。默认值为 0 |
dampening | 动量的衰减因子。默认值为 0 |
weight_decay | 权重衰减(L2 正则化),用于防止过拟合。默认值为 0 |
nesterov | 是否使用 Nesterov 动量,True 表示使用。默认值为 False |
返回
优化器对象
torch.utils
torch.utils.data
torch.utils.data.Dataloader()
批量加载数据,生成一个迭代器,随机抽取批量大小的数据
输入
参数 | 解释 |
---|---|
dataset | 数据集 |
batch_size=32 | 批次大小 |
shuffle=True | 是否打乱 |
num_workers=4 | 使用的进程数 |
返回
dataloader 迭代器
math
math.gamma
numpy
np.power
计算指数,可放入列表
np.arange
生成一维列表,含头不含尾
arange(start,stop,step,dtype)
参数 | 解释 |
---|---|
start | 开始 |
stop | 结束 |
step | 间隔 |
dtype | 数据类型 |