Skip to main content

简化基于 TensorFlow 的生成模型训练的框架

项目描述

简单GAN

执照 文件状态 下载 下载 代码风格:黑色

简化生成模型训练的框架

SimpleGAN 是一个基于TensorFlow的框架,可以更轻松地训练生成模型。SimpleGAN 为用户提供具有可定制性选项的高级 API,允许他们用几行代码训练生成模型,或者用户可以重用现有架构中的模块来运行自定义训练循环和实验。

要求

确保您已安装以下软件包

安装

最新稳定版本:

  $ pip install simplegan

最新开发版本:

  $ pip install git+https://github.com/grohith327/simplegan.git

入门

DCGAN
from simplegan.gan import DCGAN

## initialize model
gan = DCGAN() 

## load train data
train_ds = gan.load_data(use_mnist = True)

## get samples from the data object
samples = gan.get_sample(train_ds, n_samples = 5)

## train the model
gan.fit(train_ds = train_ds)

## get generated samples from model
generated_samples = gan.generate_samples(n_samples = 5)
GAN 的自定义训练循环
from simplegan.gan import Pix2Pix

## initialize model
gan = Pix2Pix()

## get generator module of Pix2Pix
generator = gan.generator() ## A tf.keras model

## get discriminator module of Pix2Pix
discriminator = gan.discriminator() ## A tf.keras model

## training loop
with tf.GradientTape() as tape:
""" Custom training loops """
卷积自动编码器
from simplegan.autoencoder import ConvolutionalAutoencoder

## initialize autoencoder
autoenc = ConvolutionalAutoencoder()

## load train and test data
train_ds, test_ds = autoenc.load_data(use_cifar10 = True)

## get sample from data object
train_sample = autoenc.get_sample(data = train_ds, n_samples = 5)
test_sample = autoenc.get_sample(data = test_ds, n_samples = 1)

## train the autoencoder
autoenc.fit(train_ds = train_ds, epochs = 5, optimizer = 'RMSprop', learning_rate = 0.002)

## get generated test samples from model
generated_samples = autoenc.generate_samples(test_ds = test_ds.take(1))

要详细查看更多示例,请查看此处

文档

查看文档页面

提供的模型

模型 生成的图像
香草自动编码器 没有任何
卷积自动编码器
变分自动编码器 [论文]
矢量量化 - 变分自动编码器 [论文]
香草 GAN [论文]
DCGAN [论文]
WGAN [论文]
CGAN [论文]
InfoGAN [论文]
Pix2Pix []
CycleGAN [论文]
3DGAN(VoxelGAN) [论文]
自注意力GAN(SAGAN)[论文]

贡献

我们感谢所有贡献。如果您计划执行错误修复、添加新功能或模型,请在提出拉取请求之前提交问题并进行讨论。

引文

@software{simplegan,
    author = {{Rohith Gandhi et al.}},
    title = {simplegan},
    url = {https://simplegan.readthedocs.io},
    version = {0.2.8},
}

贡献者

项目详情


下载文件

下载适用于您平台的文件。如果您不确定要选择哪个,请了解有关安装包的更多信息。

源分布

simplegan-0.2.9.tar.gz (33.4 kB 查看哈希

已上传 source