# SOL4Py Sample: CIFARAutoEncoder
#******************************************************************************
#
# Copyright (c) 2018-2019 Antillia.com TOSHIYUKI ARAI. ALL RIGHTS RESERVED.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#******************************************************************************
# 2019/05/10
# CIFARAutoEncoderModel.py
# encoding: utf-8
import sys
import os
import time
import traceback
import matplotlib.pyplot as plt
import numpy as np
import keras
import tensorflow as tf
from keras.utils import np_utils
#from keras import backend as K
from keras.models import model_from_json
from keras.datasets import cifar10, cifar100
sys.path.append('../../')
from SOL4Py.ZMLModel import *
from SOL4Py.ZMain import *
from SOL4Py.keras.ZEpochChangeNotifier import *
from SOL4Py.keras.ZSimpleAutoEncoderModel import *
# CIFAR dataset ids: the value passed as dataset_id to CIFARAutoEncoder,
# which compares it against these constants to choose between
# keras.datasets.cifar10 and keras.datasets.cifar100.
CIFAR10 = 0
CIFAR100 = 1
class CIFARAutoEncoder(ZMLModel):
    """Convolutional autoencoder for CIFAR-10 / CIFAR-100 images.

    Typical use is ``build()`` followed by ``predict()`` and
    ``show_images()``.  ``build()`` loads the dataset, creates the model and
    then either loads previously saved weights (when the weight file exists)
    or trains the model and saves its weights.

    NOTE(review): ``_start``, ``_end`` and ``write`` are inherited from
    ``ZMLModel`` and look like tracing/logging helpers -- confirm against
    that class.
    """

    # Image width/height in pixels and number of colour channels; together
    # they form the model's input_shape.
    IMAGE_SIZE = 32
    CHANNELS = 3

    ################
    # Inner class
    class CIFARAutoEncoderModel(ZSimpleAutoEncoderModel):
        """Encoder/decoder layer stack plugged into ZSimpleAutoEncoderModel.

        NOTE(review): ``Conv2D``, ``MaxPooling2D`` and ``UpSampling2D`` are
        not imported explicitly in this file; presumably they arrive through
        the star import of ``SOL4Py.keras.ZSimpleAutoEncoderModel`` and are
        the standard Keras layers -- confirm.
        """

        # Constructor
        def __init__(self, input_shape):
            ZSimpleAutoEncoderModel.__init__(self, input_shape)

        def encode(self, input_image):
            """Return the encoded tensor for ``input_image``.

            Three Conv2D(3x3, relu) + MaxPooling2D(2x2) stages with 64, 32
            and 16 filters.
            """
            x = Conv2D(64, (3, 3), activation='relu', padding='same')(input_image)
            x = MaxPooling2D((2, 2), padding='same')(x)
            x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
            x = MaxPooling2D((2, 2), padding='same')(x)
            x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
            encoded = MaxPooling2D((2, 2), padding='same')(x)
            return encoded

        def decode(self, encoded):
            """Return the reconstructed image tensor for ``encoded``.

            Mirrors ``encode``: three Conv2D(3x3, relu) + UpSampling2D(2x2)
            stages with 16, 32 and 64 filters, then a 3-filter sigmoid
            convolution producing the output image.
            """
            x = Conv2D(16, (3, 3), activation='relu', padding='same')(encoded)
            x = UpSampling2D((2, 2))(x)
            x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
            x = UpSampling2D((2, 2))(x)
            x = Conv2D(64, (3, 3), activation='relu', padding='same')(x)
            x = UpSampling2D((2, 2))(x)
            decoded = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)
            return decoded

    ##
    # Constructor
    def __init__(self, dataset_id=CIFAR10,
                 epochs=20, view=None, ipaddress="127.0.0.1", port=8888):
        """Store the training configuration and derive the weight file name.

        Args:
            dataset_id: CIFAR10 or CIFAR100 (module-level constants).
            epochs: number of training epochs passed to ``fit``.
            view: forwarded unchanged to ``ZMLModel.__init__``.
            ipaddress, port: forwarded to ``ZEpochChangeNotifier``.
        """
        # NOTE(review): the meaning of the leading 0 is defined by ZMLModel.
        super(CIFARAutoEncoder, self).__init__(0, view)

        self.input_shape = (self.IMAGE_SIZE, self.IMAGE_SIZE, self.CHANNELS)
        self.epochs = epochs
        self.dataset_id = dataset_id

        # The following callbacks will be used to train AutoEncoderModel.
        # NOTE(review): why epochs + 10 is passed is not visible here;
        # see ZEpochChangeNotifier.
        notifier_name = "{}-{}".format(self.__class__.__name__, self.dataset_id)
        self.callbacks = [ZEpochChangeNotifier(ipaddress, port, notifier_name,
                                               self.epochs + 10)]
        self.set_weight_filepath()

    def build(self):
        """Load data, create the model, then load weights or train and save.

        Any exception raised by a step is printed with its traceback and
        not re-raised.
        """
        self._start(self.build.__name__)
        try:
            # 1 Load cifar dataset
            self.load_dataset()

            # 2 Create a cifar AutoEncoderModel
            self.create()

            if self.is_trained():
                # 3 If our cifar model trained, i.e, if weight_filepath.h5 exists
                self.load()
                self.compile()
            else:
                self.compile()
                self.train()
                self.save()
        except Exception:
            # Exception (not a bare except) so Ctrl-C / SystemExit still
            # interrupt a long training run.
            traceback.print_exc()
        self._end(self.build.__name__)

    def set_weight_filepath(self):
        """Set ``self.weight_filepath`` to ``<ClassName>_<dataset_id>.h5``."""
        self._start(self.set_weight_filepath.__name__)
        self.weight_filepath = "{}_{}.h5".format(self.__class__.__name__,
                                                 self.dataset_id)
        self.write("weight_file " + self.weight_filepath)
        self._end(self.set_weight_filepath.__name__)

    def load_dataset(self):
        """Load the selected CIFAR images into ``x_train`` / ``x_test``.

        Pixel values are converted to float32 and divided by 255.

        Raises:
            ValueError: if ``dataset_id`` is neither CIFAR10 nor CIFAR100.
        """
        self._start(self.load_dataset.__name__)

        # We don't need labels y_train and y_test.
        if self.dataset_id == CIFAR10:
            (x_train, _), (x_test, _) = cifar10.load_data()
        elif self.dataset_id == CIFAR100:
            (x_train, _), (x_test, _) = cifar100.load_data()
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # x_train further down.
            raise ValueError("Unsupported dataset_id: {}".format(self.dataset_id))

        self.x_train = x_train.astype('float32') / 255.
        self.x_test = x_test.astype('float32') / 255.

        self._end(self.load_dataset.__name__)

    def create(self):
        """Instantiate the inner autoencoder model for ``input_shape``."""
        self._start(self.create.__name__)
        self.model = self.CIFARAutoEncoderModel(self.input_shape)
        self._end(self.create.__name__)

    def compile(self):
        """Compile the model with adadelta and binary cross-entropy loss."""
        self._start(self.compile.__name__)
        self.model.compile(optimizer='adadelta', loss='binary_crossentropy')
        self._end(self.compile.__name__)

    def train(self):
        """Fit the model to reproduce its own input and report elapsed time."""
        self._start(self.train.__name__)
        start = time.time()
        # Input and target are the same images: the network learns to
        # reconstruct them.
        self.model.fit(self.x_train, self.x_train,
                       epochs=self.epochs,
                       batch_size=128,
                       shuffle=True,
                       verbose=True,
                       validation_data=(self.x_test, self.x_test),
                       callbacks=self.callbacks)
        elapsed_time = time.time() - start
        self.write("Train elapsed_time:{0}[sec]".format(elapsed_time))
        self._end(self.train.__name__)

    def is_trained(self):
        """Return True when the weight file already exists on disk."""
        self._start(self.is_trained.__name__)
        rc = os.path.isfile(self.weight_filepath)
        if rc:
            print("weight filename {}".format(self.weight_filepath))
        self._end(self.is_trained.__name__)
        return rc

    def predict(self):
        """Store the model's reconstruction of ``x_test`` in ``decoded_images``."""
        self._start(self.predict.__name__)
        # Call self.model.predict method to get decoded_images from x_test image
        self.decoded_images = self.model.predict(self.x_test)
        self._end(self.predict.__name__)

    def load(self):
        """Load weights from ``weight_filepath`` into the model.

        A failure while reading an existing file is written out as a
        traceback and not re-raised.

        Raises:
            FileNotFoundError: if the weight file does not exist.
        """
        self._start(self.load.__name__)
        if not os.path.isfile(self.weight_filepath):
            # The previous message used a malformed format string, so
            # .format() itself raised ValueError and hid the real problem.
            raise FileNotFoundError(
                "Not found weight file: {}".format(self.weight_filepath))
        try:
            self.model.load_weights(self.weight_filepath)
            self.write("Loaded a weight file:{}".format(self.weight_filepath))
        except Exception:
            self.write(formatted_traceback())
        self._end(self.load.__name__)

    def show_images(self, n=10):
        """Plot ``n`` test images (top row) above their reconstructions.

        Requires ``predict()`` to have filled ``decoded_images``.
        """
        fig = plt.figure()  # figsize=(20, 4))
        image_shape = (self.IMAGE_SIZE, self.IMAGE_SIZE, self.CHANNELS)

        # NOTE: i serves both as the 1-based subplot slot and as the data
        # index, so images 1..n are shown and index 0 is skipped.
        for i in range(1, n + 1):
            # Row 0: original x_test images.
            # Row 1: decoded images predicted from those x_test images.
            for row, images in enumerate((self.x_test, self.decoded_images)):
                ax = plt.subplot(2, n, i + row * n)
                plt.imshow(images[i].reshape(image_shape))
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)

        fig.tight_layout()
        plt.show()

    def save(self):
        """Save the model weights to ``weight_filepath``.

        A failure is printed as a traceback and not re-raised.
        """
        self._start(self.save.__name__)
        try:
            self.model.save_weights(self.weight_filepath)
            self.write("Saved weight file {}".format(self.weight_filepath))
        except Exception:
            print(formatted_traceback())
        self._end(self.save.__name__)
#################################################
#
# Script entry point.
# NOTE(review): main() comes from the star import of SOL4Py.ZMain and
# presumably reports whether this module is being run as a script -- confirm.
if main(__name__):
    try:
        # Optional single command-line argument: number of training epochs.
        epochs = 20
        if len(sys.argv) == 2:
            epochs = int(sys.argv[1])

        model = CIFARAutoEncoder(dataset_id=CIFAR10, epochs=epochs)
        model.build()
        model.predict()
        model.show_images()
    except Exception:
        # Exception (not a bare except) so Ctrl-C / SystemExit are not
        # swallowed here.
        traceback.print_exc()
# Last modified: 20 Sep. 2019