stn.pytorch, pytorch版本的空间变压器网络

分享于 

2分钟阅读

GitHub

  繁體 雙語
pytorch version of spatial transformer networks
  • 源代码名称:stn.pytorch
  • 源代码网址:http://www.github.com/fxia22/stn.pytorch
  • stn.pytorch源代码文档
  • stn.pytorch源代码下载
  • Git URL:
    git://www.github.com/fxia22/stn.pytorch.git
    Git Clone代码到本地:
    git clone http://www.github.com/fxia22/stn.pytorch
    Subversion代码到本地:
    $ svn co --depth empty http://www.github.com/fxia22/stn.pytorch
    Checked out revision 1.
    $ cd repo
    $ svn up trunk
    
    空间变压器网络版本

    根据pytorch教程从 https://github.com/qassemoquab/stnbhwd 移植。 支持CPU和 GPU。 要使用 ffi,你需要从pip安装 cffi 包。

    插件构建和测试
    
    cd script
    
    
    ./make.sh #build cuda code, don't forget to modify -arch argument for your GPU computational capacity version
    
    
    python build.py
    
    
    python test.py
    
    
    
    

    test_stn.ipynb 里有一个演示

    模块

    STN 是空间变压器模块,采用 B*H*W*D 张量和 B*H*W*2 网格作为输入,并进行双线性采样,实现了空间变换。

    AffineGridGen 采用 B*2*3 矩阵并生成仿射变换网格。

    CylinderGridGen 采用 B*1 θ矢量并生成转换栅格,沿着x 轴重新映射equirectangular图像。

    DenseAffineGridGen 采用 B*H*W*6 张量,对每个像素进行仿射变换。 卷积空间变压器的例子可以在 test_conv_stn.ipynb 中找到。

    在演示中可以找到一个简单的带有in损失的简单楞的损失功能的例子。

    火车
    • 设置学习率乘数,1e-3或者 1e-4可以正常工作。
    • 添加一个辅助损失来对仿射变换的差异进行正则化,以避免原始图像外部采样。
    复杂网格演示

    STN能够处理复杂的网格,但是如何参数化网格是一个问题。

    image


    相关文章