torch
torch.clamp
限制张量的最大最小值
如果小于min则调整为min
如果大于max则调整为max
参数 | 解释 |
---|---|
input | 需要限制的张量 |
min | 最小值 |
max | 最大值 |
torch.nn
torch.nn.Sequential
class Sequential(nn.Module):
def __init__(self, *args):
super().__init__()
for idx, module in enumerate(args):
# 这里,module是Module子类的一个实例。我们把它保存在'Module'类的成员
# 变量_modules中。_module的类型是OrderedDict
self._modules[str(idx)] = module
def forward(self, X):
# OrderedDict保证了按照成员添加的顺序遍历它们
for block in self._modules.values():
X = block(X)
return X
torch.nn.Conv2d
class Conv2D(nn.Module):
def __init__(self, kernel_size):
super().__init__()
self.weight = nn.Parameter(torch.rand(kernel_size))
self.bias = nn.Parameter(torch.zeros(1))
def forward(self, x):
return corr2d(x, self.weight) + self.bias
参数 | 解释 |
---|---|
in_channels | 输入通道数 |
out_channels | 输出通道数(卷积核个数) |
kernel_size | 卷积核大小(可设为 int 或 (h, w) ) |
stride | 步长 |
padding | 填充 |
dilation | 膨胀卷积(稀疏卷积) |
bias | 是否使用偏置项(默认为 True ) |
torch.nn.MSELoss
均方误差损失
参数 | 解释 |
---|---|
reduction | 汇总方法 none:返回每个元素的损失值 mean:返回所有元素的损失的平均值(默认) sum:返回所有元素的损失值的和 |
torch.nn.CrossEntropyLoss
交叉熵损失
参数 | 解释 |
---|---|
reduction | 汇总方法 none:返回每个元素的损失值 mean:返回所有元素的损失的平均值(默认) sum:返回所有元素的损失值的和 |
torch.nn.init
:参数初始化:
torch.nn.init.normal_()
正态分布初始化
参数 | 解释 |
---|---|
param | 需要优化的参数 |
mean | 均值 |
std | 标准差 |
torch.nn.init.zeros_()
零初始化
参数 | 解释 |
---|---|
param | 需要优化的参数 |
torch.nn.init.xavier
torch.nn.init.xavier_uniform_()
Xavier均匀分布初始化
参数 | 解释 |
---|---|
param | 需要优化的参数 |
gain | 缩放因子,默认1 |
torch.nn.init.xavier_normal_()
Xavier正态分布初始化
参数 | 解释 |
---|---|
param | 需要优化的参数 |
gain | 缩放因子,默认1 |
torch.nn.init.constant_()
常数初始化
参数 | 解释 |
---|---|
param | 需要优化的参数 |
val | 常数 |
torch.nn.init.uniform_()
均匀分布初始化
参数 | 解释 |
---|---|
param | 需要优化的参数 |
a | 均匀分布的上界 |
b | 均匀分布的上界 |
torch.optim
torch.optim.SGD()
随机梯度下降优化器
输入
参数 | 解释 |
---|---|
params | 要优化的参数 |
lr | 学习率 |
momentum | 用于加速梯度下降,并避免在鞍点附近震荡。默认值为 0 |
dampening | 动量的衰减因子。默认值为 0 |
weight_decay | 权重衰减(L2 正则化),用于防止过拟合。默认值为 0 |
nesterov | 是否使用 Nesterov 动量,True 表示使用。默认值为 False |
返回
优化器对象
torch.utils
torch.utils.data
torch.utils.data.Dataloader()
批量加载数据,生成一个迭代器,随机抽取批量大小的数据
输入
参数 | 解释 |
---|---|
dataset | 数据集 |
batch_size=32 | 批次大小 |
shuffle=True | 是否打乱 |
num_workers=4 | 使用的进程数 |
返回
dataloader 迭代器
math
math.gamma
numpy
np.power
计算指数,可放入列表
np.arange
生成一维列表,含头不含尾
arange(start,stop,step,dtype)
参数 | 解释 |
---|---|
start | 开始 |
stop | 结束 |
step | 间隔 |
dtype | 数据类型 |