torchsample, 高级培训,数据增强和Pytorch工具

分享于 

7分钟阅读

GitHub

  繁體
Comprehensive Data Augmentation and Sampling for Pytorch
  • 源代码名称:torchsample
  • 源代码网址:http://www.github.com/ncullen93/torchsample
  • torchsample源代码文档
  • torchsample源代码下载
  • Git URL:
    git://www.github.com/ncullen93/torchsample.git
    Git Clone代码到本地:
    git clone http://www.github.com/ncullen93/torchsample
    Subversion代码到本地:
    $ svn co --depth empty http://www.github.com/ncullen93/torchsample
    Checked out revision 1.
    $ cd repo
    $ svn up trunk
    
    高级培训,数据增强和Pytorch实用程序

    v0.1.3 刚刚发布- 包含重大改进,Bug 修复和附加支持。 从发布中获取它,或者拉主分支。

    这里软件包提供了一些内容:

    • 用于keras的高级 MODULE,如带有回调,约束和regularizers的培训。
    • 全面的数据增强,转换,采样和加载
    • 效用张量和变量函数,这样你就不需要频繁的numpy

    有任何功能要求提交问题? ! 我会让它发生。特别是,任何数据增加,数据加载或者采样函数。

    要贡献请检查问题页面 tagged标记为 [contributions welcome]的那些。

    ModuleTrainer

    ModuleTrainer 类提供了高级训练接口,它在提供回调。约束。初始化。regularizers和更多的训练接口时抽象出训练循环。

    例如:

    from torchsample.modules import ModuleTrainer# Define your model EXACTLY as normalclassNetwork(nn.Module):
     def__init__(self):
     super(Network, self).__init__()
     self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
     self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
     self.fc1 = nn.Linear(1600, 128)
     self.fc2 = nn.Linear(128, 10)
     defforward(self, x):
     x = F.relu(F.max_pool2d(self.conv1(x), 2))
     x = F.relu(F.max_pool2d(self.conv2(x), 2))
     x = x.view(-1, 1600)
     x = F.relu(self.fc1(x))
     x = F.dropout(x, training=self.training)
     x =self.fc2(x)
     return F.log_softmax(x)
    model = Network()
    trainer = ModuleTrainer(model)
    trainer.compile(loss='nll_loss',
     optimizer='adadelta')
    trainer.fit(x_train, y_train, 
     val_data=(x_test, y_test),
     num_epoch=20, 
     batch_size=128,
     verbose=1)

    你还可以访问标准评估和预测功能:

    loss = model.evaluate(x_train, y_train)
    y_pred = model.predict(x_train)

    Torchsample提供了广泛的回调,通常模仿 Keras 中找到的接口:

    • EarlyStopping
    • ModelCheckpoint
    • LearningRateScheduler
    • ReduceLROnPlateau
    • CSVLogger
    from torchsample.callbacks import EarlyStopping
    callbacks = [EarlyStopping(monitor='val_loss', patience=5)]
    model.set_callbacks(callbacks)

    Torchsample还提供 regularizers:

    • L1Regularizer
    • L2Regularizer
    • L1L2Regularizer

    和约束:

    • UnitNorm
    • MaxNorm
    • NonNeg

    可以使用 正规表达式 和 module_filter 参数将regularizers和约束选择性地应用到层上。 约束可以是在任意批或者历元频率下应用的显式( 硬硬盘) 约束,或者它们可以是与regularizers类似的隐式( 柔和) 约束,后者将约束偏差添加到总模型损失。

    from torchsample.constraints import MaxNorm, NonNegfrom torchsample.regularizers import L1Regularizer# hard constraint applied every 5 batcheshard_constraint = MaxNorm(value=2., frequency=5, unit='batch', module_filter='*fc*')# implicit constraint added as a penalty term to model losssoft_constraint = NonNeg(lagrangian=True, scale=1e-3, module_filter='*fc*')
    constraints = [hard_constraint, soft_constraint]
    model.set_constraints(constraints)
    regularizers = [L1Regularizer(scale=1e-4, module_filter='*conv*')]
    model.set_regularizers(regularizers)

    你还可以直接安装在 torch.utils.data.DataLoader 上,并且可以设置验证集:

    from torchsample import TensorDatasetfrom torch.utils.data import DataLoader
    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=32)
    val_dataset = TensorDataset(x_val, y_val)
    val_loader = DataLoader(val_dataset, batch_size=32)
    trainer.fit_loader(loader, val_loader=val_loader, num_epoch=100)

    命令行实用程序功能

    最后,torchsample提供了一些常用的实用工具函数:

    张量函数

    • th_iterproduct ( 模仿 itertools.product )
    • th_gather_nd ( torch.gather的n维版本)
    • th_random_choice ( 模仿 np.random.choice )
    • th_pearsonr ( 模仿 scipy.stats. pearsonr )
    • th_corrcoef ( 模仿 np.corrcoef )
    • th_affine2dth_affine3d ( torch.Tensors 上的仿射变换)

    变量函数

    • F_affine2dF_affine3d
    • F_map_coordinates2dF_map_coordinates3d

    数据增强和数据集

    torchsample包提供了大量的数据增强和转换工具,可以在数据加载过程中应用。 包还提供了灵活的TensorDatasetFolderDataset 类来处理大多数数据集需求。

    Torch 转换

    这些变换直接在 Torch 张量上

    • Compose()
    • AddChannel()
    • SwapDims()
    • RangeNormalize()
    • StdNormalize()
    • Slice2D()
    • RandomCrop()
    • SpecialCrop()
    • Pad()
    • RandomFlip()
    • ToTensor()

    仿射变换

    OriginalTransformed

    下面的变换在 Torch 张量上执行仿射( 或者仿射像) 变换。

    • Rotate()
    • Translate()
    • Shear()
    • Zoom()

    我们还提供了一个用于连接多个仿射变换的类,以便只进行一个插值:

    • Affine()
    • AffineCompose()

    数据集和采样

    提供以下数据集,提供用于在内存或者out-of-memory数据中进行抽样的一般结构和迭代器的一般结构和迭代器:

    • TensorDataset()

    • FolderDataset()

    确认

    感谢以下人员和贡献者:

    • 所有Keras贡献者
    • @deallynomore
    • @recastrodiaz

    数据  HIG  UTIL  UTI  Level  Utilities  
    相关文章