Binarygan源码分析
https://github.com/salu133445/binarygan
数据集
使用sharedarray生成数据集,通过
|
|
其实是将文件存放在了/dev/shm里,默认文件为/dev/shm/_binarized_mnist_x
config开发模式
配置项都存放在config.py中了。
版本升级tf1.0->2.0
将import tensorflow as tf替换为
import tensorflow.compat.v1 as tf tf.disable_v2_behavior()
即可
model: binarygan, gan gan_type: gan, wgan, wgan-gp
ld
apt install cuda-nvrtc-dev-10-2 libnvinfer-plugin-dev
tfds
data, info = tfds.load(‘datasetn_name’,with_info) 注意data中的格式完全由info指定,所以info必看。 data: dict info: DatasetInfo
data字典格式由info的splits字段定义,如训练集和测试集(train/test)
train: Dataset/DatasetV1Adapter test: Dataset/DatasetV1Adapter
数据集格式由info的features字段定义,如image和label,如此才得能得到x,y数据。
数据在某些维度上可能是1,需要处理: tf.squeeze() np.squeeze()
数据的类型也可能要转换 tf.cast() np.narray.astype()
- 原文作者:mlyixi
- 原文链接:https://mlyixi.github.io/post/ml/binarygan%E6%BA%90%E7%A0%81%E5%88%86%E6%9E%90/
- 版权声明:本作品采用知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议进行许可,非商业转载请注明出处(作者,原文链接),商业转载请联系作者获得授权。