SOL4Py Sample: RoadSignsModel
|
#******************************************************************************
#
# Copyright (c) 2018-2019 Antillia.com TOSHIYUKI ARAI. ALL RIGHTS RESERVED.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
#******************************************************************************
# 2019/05/29
# RoadSignsModel.py
# encodig: utf-8
import sys
import os
import cv2
import time
import traceback
from keras.preprocessing.image import ImageDataGenerator
sys.path.append('../../')
from SOL4Py.ZMain import *
from SOL4Py.ZMLModel import *
from SOL4Py.keras.ZSimpleSequentialModel import *
from SOL4Py.keras.ZEpochChangeNotifier import *
############################################################
# Image Model class based on Keras ImageDataGenerator.
class RoadSignsModel(ZMLModel):
## Class Variables
IMAGE_MODEL = 0
IMAGE_WIDTH = 64
IMAGE_HEIGHT = 64
TRAIN_DATA_DIR = "./dataset/train"
VALID_DATA_DIR = "./dataset/valid"
##
# Constructor
def __init__(self, dataset_id, epochs=10, mainv=None, ipaddress="127.0.0.1", port=7777):
super(RoadSignsModel, self).__init__(dataset_id, mainv)
self._start(self.__init__.__name__)
self.write("dataset_id:{}, ephochs:{}, mainv:{}".format(dataset_id, epochs, mainv) )
self.model = None # Keras model
self.dataset_id = dataset_id
self.dataset = None
self.epochs = epochs
self.set_dataset_id(dataset_id)
self.callbacks = [ZEpochChangeNotifier(ipaddress, port, self.__class__.__name__, self.epochs+10)]
self.classes = sorted( os.listdir(self.TRAIN_DATA_DIR) )
self.save_class_names(self.classes) # 2019/09/20
self.n_classes = len(self.classes)
self.image_width = self.IMAGE_WIDTH
self.image_height = self.IMAGE_HEIGHT
self.IMAGE_SIZE = (self.image_width, self.image_height)
self._end(self.__init__.__name__)
def set_dataset_id(self, dataset_id):
self._start(self.set_dataset_id.__name__)
self.dataset_id = dataset_id
self.weight_file = self.__class__.__name__ + "_" + str(self.dataset_id) + ".h5"
self.write("weight_file " + self.weight_file)
self._end(self.set_dataset_id.__name__)
def build(self):
self.write("====================================")
self._start(self.build.__name__)
if self.is_trained() != True:
try:
# Create ImageDataGenerator
self.create_generator()
self.create_flow()
self.create()
self.compile()
self.train()
#self.evaluate()
self.save()
self.plot()
except:
traceback.print_exc()
self._end(self.build.__name__)
def create_generator(self):
self._start(self.create_generator.__name__)
self.train_data_generator = ImageDataGenerator(
rescale = 1.0/255.0,
rotation_range = 20,
width_shift_range = 1.0,
height_shift_range = 0.3,
shear_range = 0.4,
zoom_range = 0.3,
brightness_range = [0.7,1.2],
channel_shift_range= 2.0,
horizontal_flip = False,
vertical_flip = False)
self.valid_data_generator = ImageDataGenerator(
rescale = 1.0/255.0)
self._end(self.create_generator.__name__)
def create_flow(self):
self._start(self.create_flow.__name__)
self.BATCH_SIZE = 32
self.CLASS_MODE = "categorical"
self.COLOR_MODE = "rgb"
self.train_flow = self.train_data_generator.flow_from_directory(
self.TRAIN_DATA_DIR,
target_size = self.IMAGE_SIZE,
batch_size = self.BATCH_SIZE,
class_mode = self.CLASS_MODE,
color_mode = self.COLOR_MODE,
shuffle = True)
self.valid_flow = self.valid_data_generator.flow_from_directory(
self.VALID_DATA_DIR,
target_size = self.IMAGE_SIZE,
batch_size = self.BATCH_SIZE,
class_mode = self.CLASS_MODE,
color_mode = self.COLOR_MODE,
shuffle = True)
self._end(self.create_flow.__name__)
# Create a sequential model
def create(self):
self._start(self.create.__name__)
input_shape = (self.image_width, self.image_height, 3)
self.model = ZSimpleSequentialModel(input_shape, self.n_classes)
self._end(self.create.__name__)
def compile(self):
self._start(self.compile.__name__)
self.model.compile(optimizer='adam', loss='categorical_crossentropy', metrics = ['accuracy'])
self._end(self.compile.__name__)
def train(self, train_steps =100, valid_steps = 20):
self._start(self.train.__name__)
start = time.time()
self.train_steps = train_steps
self.valid_steps = valid_steps
self.model.fit_generator(
self.train_flow,
steps_per_epoch = self.train_steps,
epochs = self.epochs,
callbacks = self.callbacks,
validation_data = self.valid_flow,
validation_steps = self.valid_steps)
elapsed_time = time.time() - start
elapsed = str("Train elapsed_time:{0}".format(elapsed_time) + "[sec]")
self.write(elapsed)
self.model.summary()
self._end(self.train.__name__)
def predict(self, image):
prediction = self.model.predict(image)
return prediction
def save(self):
self._start(self.save.__name__)
self.model.save_weights(self.weight_file)
self.write("Saved weight file {}".format(self.weight_file))
self._end(self.save.__name__)
def load(self):
self._start(self.load.__name__)
try:
self.model.load_weights(self.weight_file)
self.write("Loaded a weight file:{}".format(self.weight_file))
except:
self.write( formatted_traceback() )
self._end(self.load.__name__)
def get_model(self):
return self.model
def is_trained(self):
rc = False
if os.path.isfile(self.weight_file) == True:
rc = True
return rc
def plot(self, filename=None):
from keras.utils import plot_model
if filename == None:
filename = self.__class__.__name__ + "_model.png"
plot_model(self.model, to_file=filename,show_shapes=True)
############################################################
#
if main(__name__):
try:
app_name = os.path.basename(sys.argv[0])
dataset_id = RoadSignsModel.IMAGE_MODEL
epochs = 20
if len(sys.argv) >= 2:
dataset_id = int(sys.argv[1])
if len(sys.argv) >= 3:
epochs = int(sys.argv[2])
model = RoadSignsModel(dataset_id, epochs, None)
model.build()
except:
traceback.print_exc()
Last modified:20 Sep. 2019