简化基于 TensorFlow 的生成模型训练的框架
项目描述
简单GAN
简化生成模型训练的框架
SimpleGAN 是一个基于TensorFlow的框架,可以更轻松地训练生成模型。SimpleGAN 为用户提供具有可定制性选项的高级 API,允许他们用几行代码训练生成模型,或者用户可以重用现有架构中的模块来运行自定义训练循环和实验。
要求
确保您已安装以下软件包
安装
最新稳定版本:
$ pip install simplegan
最新开发版本:
$ pip install git+https://github.com/grohith327/simplegan.git
入门
DCGAN
from simplegan.gan import DCGAN
## initialize model
gan = DCGAN()
## load train data
train_ds = gan.load_data(use_mnist = True)
## get samples from the data object
samples = gan.get_sample(train_ds, n_samples = 5)
## train the model
gan.fit(train_ds = train_ds)
## get generated samples from model
generated_samples = gan.generate_samples(n_samples = 5)
GAN 的自定义训练循环
from simplegan.gan import Pix2Pix
## initialize model
gan = Pix2Pix()
## get generator module of Pix2Pix
generator = gan.generator() ## A tf.keras model
## get discriminator module of Pix2Pix
discriminator = gan.discriminator() ## A tf.keras model
## training loop
with tf.GradientTape() as tape:
""" Custom training loops """
卷积自动编码器
from simplegan.autoencoder import ConvolutionalAutoencoder
## initialize autoencoder
autoenc = ConvolutionalAutoencoder()
## load train and test data
train_ds, test_ds = autoenc.load_data(use_cifar10 = True)
## get sample from data object
train_sample = autoenc.get_sample(data = train_ds, n_samples = 5)
test_sample = autoenc.get_sample(data = test_ds, n_samples = 1)
## train the autoencoder
autoenc.fit(train_ds = train_ds, epochs = 5, optimizer = 'RMSprop', learning_rate = 0.002)
## get generated test samples from model
generated_samples = autoenc.generate_samples(test_ds = test_ds.take(1))
要详细查看更多示例,请查看此处
文档
查看文档页面
提供的模型
模型 | 生成的图像 |
---|---|
香草自动编码器 | 没有任何 |
卷积自动编码器 | |
变分自动编码器 [论文] | |
矢量量化 - 变分自动编码器 [论文] | |
香草 GAN [论文] | |
DCGAN [论文] | |
WGAN [论文] | |
CGAN [论文] | |
InfoGAN [论文] | |
Pix2Pix [纸] | |
CycleGAN [论文] | |
3DGAN(VoxelGAN) [论文] | |
自注意力GAN(SAGAN)[论文] |
贡献
我们感谢所有贡献。如果您计划执行错误修复、添加新功能或模型,请在提出拉取请求之前提交问题并进行讨论。
引文
@software{simplegan,
author = {{Rohith Gandhi et al.}},
title = {simplegan},
url = {https://simplegan.readthedocs.io},
version = {0.2.8},
}