import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS

# Command-line configuration for the salt-and-pepper noise demo in main().
# Help strings were previously empty, leaving `--help` uninformative.
flags.DEFINE_integer('batch_size', 20,
                     'Number of MNIST images (and noise masks) per batch.')
flags.DEFINE_float('salt_ratio', 0.01,
                   'Approximate fraction of pixels forced to 1.0 (salt).')
flags.DEFINE_float('pepper_ratio', 0.1,
                   'Approximate fraction of pixels forced to 0.0 (pepper).')
flags.DEFINE_string('summary_dir', 'summary',
                    'Output directory for image summaries; deleted at start.')
from tensorflow.examples.tutorials.mnist import input_data
def main():
    """Apply salt-and-pepper noise to one MNIST batch and write image summaries.

    A batch of ``FLAGS.batch_size`` MNIST training images is combined with a
    freshly sampled uniform noise tensor:

    * salt:   pixels whose noise value is >= ``1 - salt_ratio`` become 1.0
    * pepper: pixels whose noise value is <  ``pepper_ratio`` become 0.0
              (applied after salt, so pepper wins where both fire)

    The input, raw noise, salt mask, pepper mask and noised result are each
    written as an image summary into its own sub-directory of
    ``FLAGS.summary_dir``, which is deleted first if it already exists.
    """
    with tf.variable_scope('input'):
        # Flat 28*28 pixel rows, reshaped to [batch, 28, 28, 1] images.
        # Named input_pixels to avoid shadowing the ``input`` builtin.
        input_pixels = tf.placeholder(tf.float32, shape=[None, 28 * 28])
        input_image = tf.reshape(input_pixels, [-1, 28, 28, 1])

    with tf.variable_scope('random'):
        # One uniform [0, 1) sample per pixel. The batch dimension is fixed
        # to FLAGS.batch_size, so the fed batch is expected to match it.
        random_count = FLAGS.batch_size
        random_image = tf.random_uniform(
            shape=[random_count, 28, 28, 1], minval=0.0, maxval=1.0)

    with tf.variable_scope('salt'):
        salt_ratio = tf.placeholder(tf.float32)
        # 1.0 where noise >= 1 - salt_ratio (probability ~salt_ratio), else 0.0.
        salt_image = tf.to_float(
            tf.greater_equal(random_image, 1.0 - salt_ratio))

    with tf.variable_scope('pepper'):
        pepper_ratio = tf.placeholder(tf.float32)
        # 0.0 where noise < pepper_ratio (probability ~pepper_ratio), else 1.0.
        pepper_image = tf.to_float(
            tf.greater_equal(random_image, pepper_ratio))

    with tf.variable_scope('noised'):
        # maximum() with the salt mask raises salted pixels to 1.0;
        # minimum() with the pepper mask then drops peppered pixels to 0.0.
        noised_image = tf.minimum(
            tf.maximum(input_image, salt_image), pepper_image)

    mnist = input_data.read_data_sets("MNIST_DATA/")

    with tf.Session() as sess:
        # Start from a clean summary directory so earlier runs don't mix in.
        if tf.gfile.Exists(FLAGS.summary_dir):
            tf.gfile.DeleteRecursively(FLAGS.summary_dir)

        names = ['0-input', '1-random', '2-salt', '3-pepper', '4-noised']
        images = [input_image, random_image, salt_image, pepper_image,
                  noised_image]
        summaries = [
            tf.image_summary('image', image, max_images=FLAGS.batch_size)
            for image in images
        ]

        batch = mnist.train.next_batch(FLAGS.batch_size)
        results = sess.run(summaries, feed_dict={
            input_pixels: batch[0],
            salt_ratio: FLAGS.salt_ratio,
            pepper_ratio: FLAGS.pepper_ratio,
        })

        for result, name in zip(results, names):
            writer = tf.train.SummaryWriter(
                '%s/%s' % (FLAGS.summary_dir, name), sess.graph)
            try:
                writer.add_summary(result)
            finally:
                # The writer buffers events; close() flushes them to disk.
                # Without it the summaries may never be written before exit.
                writer.close()
# Script entry point. main() takes no arguments, so it is called directly
# rather than through tf.app.run(), which would pass argv to it.
if __name__ == '__main__':
    main()