PyTorch下划线函数作用与总结

下划线函数

PyTorch中有一些功能拥有两个名字类似的函数,一个带下划线,一个不带下划线,以转置操作t为例,有t()t_()两个版本.
查看文档https://pytorch.org/docs/stable/torch.html#torch.thttps://pytorch.org/docs/stable/tensors.html#torch.Tensor.t_可知:

  • 不带下划线版本不改变对象本身,返回转置后的tensor对象
  • 带下划线版本是就地(in-place)操作,会改变当前对象本身,并且也会返回转置后的tensor对象

除了上面说的两个版本的对比,在torch.nn.init模块里也有许多带下划线的函数,但是在模块内并没有相应的无下划线函数.这些函数主要是用来初始化变量的,需要把变量传入函数,而不像上面所说作为成员函数调用.

1
2
3
w = torch.empty(3, 5)
nn.init.eye_(w) # 将w初始化为单位阵,w作为参数传入函数
# 该函数无下划线版本是 torch.eye(), 传入参数不同

总而言之,带下划线函数会改变对象本身或传入的对象,这点在用的时候需要注意.

下划线函数有哪些?

在PyTorch的doc中打开Index,使用正则表达式[a-z]+_\(\)过滤即可得到所有带下划线的函数.