# SOL4Py Sample: MNISTAutoEncoder
#
#******************************************************************************
#
# Copyright (c) 2018-2019 Antillia.com TOSHIYUKI ARAI. ALL RIGHTS RESERVED.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#******************************************************************************
# 2019/05/10
# MNISTAutoEncoderModel.py
# This is based on the following sample program.
# https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py
# See also https://keras.io/datasets/
# encoding: utf-8
import sys
import os
import time
import traceback
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import keras
import tensorflow as tf
from keras.utils import np_utils
from keras.datasets import mnist
from keras import backend as K
from keras.models import model_from_json
import numpy as np
sys.path.append('../../')
from SOL4Py.ZMLModel import *
from SOL4Py.ZMain import *
from SOL4Py.keras.ZEpochChangeNotifier import *
from SOL4Py.keras.ZSimpleAutoEncoderModel import *
class MNISTAutoEncoder(ZMLModel):
    """Auto-encoder sample that reconstructs MNIST digit images.

    The class loads the MNIST images, builds a ``ZSimpleAutoEncoderModel``,
    then either loads previously saved weights or trains the model and saves
    the weights next to this source file.
    """

    IMAGE_SIZE = 28
    CHANNELS = 1  # 1: gray scale

    ##
    # Constructor
    def __init__(self, epochs, mainv=None, ipaddress="127.0.0.1", port=8888):
        """Store the training settings and prepare the epoch notifier.

        Args:
            epochs: Number of epochs passed to ``fit`` in ``train``.
            mainv: Forwarded unchanged to the ``ZMLModel`` constructor.
            ipaddress: Address handed to ``ZEpochChangeNotifier``.
            port: Port handed to ``ZEpochChangeNotifier``.
        """
        super(MNISTAutoEncoder, self).__init__(0, mainv)

        self.input_shape = (self.IMAGE_SIZE, self.IMAGE_SIZE, self.CHANNELS)
        self.epochs = epochs
        self.dataset_id = 0

        notifier_name = self.__class__.__name__ + "-" + str(self.dataset_id)
        self.callbacks = [
            ZEpochChangeNotifier(ipaddress, port, notifier_name,
                                 self.epochs + 10),
        ]
        self.set_weight_filepath()

    def build(self):
        """Prepare a ready-to-use model, training it only when necessary.

        Any exception raised along the way is reported with a printed
        traceback and is not propagated to the caller.
        """
        self._start(self.build.__name__)
        try:
            # 1. Load MNIST dataset
            self.load_dataset()

            # 2. Create MNISTAutoEncoderModel
            self.create()

            if self.is_trained():
                # 3. If our model is already trained, load its weight file.
                self.load()
                # 4. Compile
                self.compile()
            else:
                # 5. If not trained
                self.compile()
                # 6. Train MNISTAutoEncoderModel
                self.train()
                # 7. Save the weight file
                self.save()
        except Exception:
            traceback.print_exc()
        self._end(self.build.__name__)

    def set_weight_filepath(self):
        """Set ``self.weight_filepath`` to ``<ClassName>_<dataset_id>.h5``.

        The file is located in the directory that contains this source file.
        """
        self._start(self.set_weight_filepath.__name__)
        weight_file = self.__class__.__name__ + "_" + str(self.dataset_id) + ".h5"
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.weight_filepath = os.path.join(current_dir, weight_file)
        self.write("WeightFilePath " + self.weight_filepath)
        self._end(self.set_weight_filepath.__name__)

    def load_dataset(self):
        """Load the MNIST images into ``self.x_train`` and ``self.x_test``.

        Labels are discarded. Pixel values are converted to float32 and
        divided by 255, and each set is reshaped to
        (count, IMAGE_SIZE, IMAGE_SIZE, CHANNELS).
        """
        (x_train, _), (x_test, _) = mnist.load_data()

        x_train = x_train.astype('float32') / 255.
        x_test = x_test.astype('float32') / 255.

        image_shape = (self.IMAGE_SIZE, self.IMAGE_SIZE, self.CHANNELS)
        self.x_train = np.reshape(x_train, (len(x_train),) + image_shape)
        self.x_test = np.reshape(x_test, (len(x_test),) + image_shape)

    def create(self):
        """Instantiate ``ZSimpleAutoEncoderModel`` for ``self.input_shape``."""
        self._start(self.create.__name__)
        self.model = ZSimpleAutoEncoderModel(self.input_shape)
        self._end(self.create.__name__)

    def compile(self):
        """Compile the model with adadelta and binary cross-entropy loss."""
        self._start(self.compile.__name__)
        self.model.compile(optimizer='adadelta', loss='binary_crossentropy')
        self._end(self.compile.__name__)

    def train(self):
        """Fit the model to reproduce its own input and log the elapsed time.

        The training images are used as both input and target; the test
        images are used the same way as validation data.
        """
        self._start(self.train.__name__)
        start = time.time()
        self.model.fit(self.x_train, self.x_train,
                       epochs=self.epochs,
                       batch_size=128,
                       shuffle=True,
                       verbose=True,
                       validation_data=(self.x_test, self.x_test),
                       callbacks=self.callbacks)
        elapsed_time = time.time() - start
        self.write("Train elapsed_time:{0}[sec]".format(elapsed_time))
        self._end(self.train.__name__)

    def is_trained(self):
        """Return True when the weight file already exists on disk."""
        # The start/end markers must name this method; no `trained` method
        # is defined in this class, so `self.trained.__name__` was wrong.
        self._start(self.is_trained.__name__)
        rc = os.path.isfile(self.weight_filepath)
        self._end(self.is_trained.__name__)
        return rc

    def predict(self):
        """Run the model on ``self.x_test`` and keep the decoded images.

        The result is stored in ``self.decoded_images``.
        """
        self.decoded_images = self.model.predict(self.x_test)

    def load(self):
        """Load the weights from ``self.weight_filepath`` into the model.

        A failure is written out as a formatted traceback, not raised.
        """
        self._start(self.load.__name__)
        try:
            self.model.load_weights(self.weight_filepath)
            self.write("Loaded a weight file:{}".format(self.weight_filepath))
        except Exception:
            self.write(formatted_traceback())
        self._end(self.load.__name__)

    def show_images(self, n=10):
        """Plot ``n`` test images above their decoded counterparts.

        ``predict`` must have been called first so that
        ``self.decoded_images`` exists.

        Args:
            n: Number of image pairs (columns) to display.
        """
        fig = plt.figure(figsize=(20, 4))
        side = self.IMAGE_SIZE
        # NOTE: the subplot position (1-based) is also used as the image
        # index, so images 1..n are shown and image 0 is skipped.
        for i in range(1, n + 1):
            # Display original x_test images
            ax = plt.subplot(2, n, i)
            plt.imshow(self.x_test[i].reshape(side, side))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)

            # Display decoded images predicted from original x_test images.
            ax = plt.subplot(2, n, i + n)
            plt.imshow(self.decoded_images[i].reshape(side, side))
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        fig.tight_layout()
        plt.show()

    def save(self):
        """Save the model weights to ``self.weight_filepath``.

        A failure is written out as a formatted traceback, not raised.
        """
        self._start(self.save.__name__)
        try:
            self.model.save_weights(self.weight_filepath)
            self.write("Saved weight file {}".format(self.weight_filepath))
        except Exception:
            # Report through self.write, consistent with load().
            self.write(formatted_traceback())
        self._end(self.save.__name__)
#################################################
#
if main(__name__):
    # Script entry point: build (load or train) the auto-encoder, run it on
    # the test images and display originals next to reconstructions.
    try:
        # Default epoch count; a single command-line argument overrides it.
        epochs = 10
        if len(sys.argv) == 2:
            epochs = int(sys.argv[1])

        model = MNISTAutoEncoder(epochs)
        model.build()
        model.predict()
        model.show_images()
    except Exception:
        # Top-level boundary: report the failure without hiding
        # KeyboardInterrupt/SystemExit.
        traceback.print_exc()
# Last modified: 20 Sep. 2019