SOL4Py Sample: CIFARClassifier

SOL4Py Samples











#******************************************************************************
#
#  Copyright (c) 2018 Antillia.com TOSHIYUKI ARAI. ALL RIGHTS RESERVED.
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#******************************************************************************

 
# 2018/09/20

# On CIFAR-10 dataset, see the following page:

# http://www.cs.toronto.edu/~kriz/cifar.html

# See also:
# https://github.com/ageron/tensorflow-models/blob/master/slim/datasets/download_and_convert_cifar.py

#  CIFARClassifier.py

# encoding: utf-8

import sys
import os
import time
import traceback
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
import numpy as np

from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPool2D
from keras.layers.core import Dense,Activation,Dropout,Flatten
from keras.datasets import cifar10
from keras.utils import np_utils
from keras.preprocessing.image import load_img, img_to_array


from PyQt5.QtCore    import *
from PyQt5.QtWidgets import *
from PyQt5.QtGui     import *
 
sys.path.append('../../')

from SOL4Py.ZApplicationView import *
from SOL4Py.ZLabeledComboBox import *
from SOL4Py.ZPushButton      import *
from SOL4Py.ZVerticalPane    import *
from SOL4Py.ZPILImageCropper import *
 
from SOL4Py.ZScrolledPlottingArea import *
from SOL4Py.ZScalableScrolledImageView import *
from SOL4Py.ZTabbedWindow import *

from CIFARModel import *

# Dataset identifiers: passed to CIFARModel and used as indices into the
# per-dataset class-name lists held by MainView.
CIFAR10, CIFAR100 = 0, 1

############################################################
# Classifier View

class MainView(ZApplicationView):
  """Main window of the CIFAR classifier sample.

  Layout: a dataset selector with "Classify" and "Clear" buttons in the top
  dock, a text log on the left, and a tabbed pane on the right showing the
  cropped source image ("SourceImage") and the 32x32 model input
  ("TestImage").  A CIFARModel for the selected dataset is loaded when
  CIFARModel.trained() reports that its model/weight files exist.
  """

  def __init__(self, title, x, y, width, height):
    """Build the widgets, try to load the CIFAR10 model, and show the window.

    Args:
      title:          application title forwarded to ZApplicationView.
      x, y:           initial window position.
      width, height:  initial window size; also used to size child widgets.
    """
    super(MainView, self).__init__(title, x, y, width, height)
    self.font        = QFont("Arial", 10)
    self.setFont(self.font)

    # True only while a trained model for self.dataset_id is loaded.
    self.model_loaded = False

    # Model input built by load_file(): float32 pixels divided by 255 with a
    # leading batch axis; None until an image has been loaded successfully.
    self.image       = None

    # Class names indexed by dataset id (CIFAR10 / CIFAR100).
    # See https://github.com/ageron/tensorflow-models/blob/master/slim/datasets/download_and_convert_cifar.py
    self.class_names_set = [None, None]

    self.class_names_set[CIFAR10] = [
                'airplane','automobile', 'bird',     'cat',    'deer',
                'dog',     'frog',       'horse',    'ship',   'truck',]

    self.class_names_set[CIFAR100] = [
                'apple',      'aquarium_fish', 'baby',       'bear',         'beaver',
                'bed',        'bee',           'beetle',     'bicycle',      'bottle',
                'bowl',       'boy',           'bridge',     'bus',          'butterfly',
                'camel',      'can',           'castle',     'caterpillar',  'cattle',
                'chair',      'chimpanzee',    'clock',      'cloud',        'cockroach',
                'couch',      'crab',          'crocodile',  'cup',          'dinosaur',
                'dolphin',    'elephant',      'flatfish',   'forest',       'fox',
                'girl',       'hamster',       'house',      'kangaroo',     'keyboard',
                'lamp',       'lawn_mower',    'leopard',    'lion',         'lizard',
                'lobster',    'man',           'maple_tree', 'motorcycle',   'mountain',
                'mouse',      'mushroom',      'oak_tree',   'orange',       'orchid',
                'otter',      'palm_tree',     'pear',       'pickup_truck', 'pine_tree',
                'plain',      'plate',         'poppy',      'porcupine',    'possum',
                'rabbit',     'raccoon',       'ray',        'road',         'rocket',
                'rose',       'sea',           'seal',       'shark',        'shrew',
                'skunk',      'skyscraper',    'snail',      'snake',        'spider',
                'squirrel',   'streetcar',     'sunflower',  'sweet_pepper', 'table',
                'tank',       'telephone',     'television', 'tiger',        'tractor',
                'train',      'trout',         'tulip',      'turtle',       'wardrobe',
                'whale',      'willow_tree',   'wolf',       'woman',        'worm',]

    # 1 Add a labeled combobox (and buttons) to the top dock area.
    self.add_datasets_combobox()

    # Qt geometry setters take ints; "/" would produce floats in Python 3,
    # which recent PyQt5/Python combinations reject, so use "//".
    # 2 Add a text editor (log output) to the left pane of the center area.
    self.text_editor = QTextEdit()
    self.text_editor.setLineWrapColumnOrWidth(600)
    self.text_editor.setLineWrapMode(QTextEdit.FixedPixelWidth)
    self.text_editor.setGeometry(0, 0, width // 2, height)

    # 3 Add a tabbed window to the right pane of the center area.
    self.tabbed_window = ZTabbedWindow(self, 0, 0, width // 2, height)

    # 3.1 Tab showing the cropped source image.
    self.image_view = ZScalableScrolledImageView(self, 0, 0, width // 3, height // 3)
    self.tabbed_window.add("SourceImage", self.image_view)

    # 3.2 Tab showing the 32x32 image actually fed to the model.
    self.test_image_view = ZScalableScrolledImageView(self, 0, 0, width // 3, height // 3)
    self.tabbed_window.add("TestImage", self.test_image_view)

    self.add(self.text_editor)
    self.add(self.tabbed_window)

    # 4 Load the trained model for the initial dataset. CIFARModel receives
    # this view (mainv=self); the text editor therefore exists beforehand.
    self.model = CIFARModel(self.dataset_id, mainv=self)
    self._load_model()

    self.show()


  def _load_model(self):
    """Load, compile and evaluate the model for self.dataset_id if trained.

    Sets self.model_loaded to True only after the model has been loaded.
    When the model/weight files are missing, prints a hint and shows a
    warning telling the user which command creates them.
    """
    # Reset first so a previous dataset's model is never treated as current.
    self.model_loaded = False
    if not self.model.trained():
      command = "python CIFARModel.py " + str(self.dataset_id)
      print("You have to create a model file and weight file")
      print("Please run: " + command)
      QMessageBox.warning(self, "CIFAR",
           "Model/Weight File Missing.\nPlease run: " + command)
      return

    self.model.load_dataset()
    self.model.load()
    self.model.compile()
    self.model.evaluate()
    self.model_loaded = True


  def add_datasets_combobox(self):
    """Create the dataset combobox plus the Classify/Clear buttons in the top dock."""
    self.dataset_id = CIFAR10
    self.datasets_combobox = ZLabeledComboBox(self, "Datasets", Qt.Horizontal)
    self.datasets_combobox.setFont(self.font)

    # Display name -> dataset id used by datasets_activated().
    self.datasets = {"CIFAR10": CIFAR10, "CIFAR100": CIFAR100}
    title = self.get_title()

    self.setWindowTitle(self.__class__.__name__ + " - " + title)

    self.datasets_combobox.add_items(self.datasets.keys())
    self.datasets_combobox.add_activated_callback(self.datasets_activated)
    self.datasets_combobox.set_current_text(self.dataset_id)

    # Classification stays disabled until an image has been opened.
    self.classifier_button = ZPushButton("Classify", self)
    self.classifier_button.setEnabled(False)

    self.clear_button = ZPushButton("Clear", self)

    self.classifier_button.add_activated_callback(self.classifier_button_activated)
    self.clear_button.add_activated_callback(self.clear_button_activated)

    self.datasets_combobox.add(self.classifier_button)
    self.datasets_combobox.add(self.clear_button)

    self.set_top_dock(self.datasets_combobox)


  def write(self, text):
    """Append *text* to the log and repaint immediately."""
    self.text_editor.append(text)
    self.text_editor.repaint()


  def datasets_activated(self, text):
    """Switch to the dataset named *text* and (re)load its model."""
    self.dataset_id = self.datasets[text]
    title = self.get_title()
    self.setWindowTitle(text + " - " + title)

    # Require a fresh file_open() before classifying with the new model.
    self.classifier_button.setEnabled(False)

    self.model.set_dataset_id(self.dataset_id)
    self._load_model()


  def file_open(self):
    """Show a file dialog, load the chosen image, and enable "Classify".

    Warns instead when no trained model is loaded for the current dataset.
    """
    if not self.model_loaded:
      QMessageBox.warning(self, "CIFAR: Weight File Missing",
           "Please run: python CIFARModel.py " + str(self.dataset_id))
      return

    options = QFileDialog.Options()
    # Qt separates patterns inside one filter with spaces and filters with ";;".
    filename, _ = QFileDialog.getOpenFileName(self, "FileOpenDialog", "",
                   "All Files (*);;Image Files (*.png *.jpg *.jpeg)", options=options)
    if filename:
      self.load_file(filename)
    # self.image is None when nothing was ever loaded or the last load failed.
    if self.image is not None:
      self.classifier_button.setEnabled(True)


  def load_file(self, filename):
    """Load *filename*, display it, and prepare self.image for the model.

    The maximum square region is cropped into a temporary PNG, shown in the
    SourceImage tab, reloaded at 32x32, shown in the TestImage tab, divided
    by 255 as float32 and given a leading batch axis.  On any error the
    traceback is written to the log and self.image is left as None.
    """
    # Forget the previous image so a failed load cannot be classified.
    self.image = None
    cropped_file = "./~temp_cropped.png"
    try:
      # 1 Crop the maximum square region and save it as cropped_file.
      image_cropper = ZPILImageCropper(filename)
      image_cropper.crop_maximum_square_region(cropped_file)

      # 2 Show the cropped image in the SourceImage tab.
      self.image_view.load_image(cropped_file)
      self.set_filenamed_title(filename)

      # 3 Reload the cropped file as a 32x32 Pillow image.
      pil_image = load_img(cropped_file, target_size=(32, 32))

      # 4 Convert to a numpy ndarray and show it in the TestImage tab.
      pixels = img_to_array(pil_image)
      self.test_image_view.set_image(pixels)

      # 5 Scale to float32 / 255 and add the batch axis expected by predict().
      pixels = pixels.astype('float32') / 255.0
      self.image = np.expand_dims(pixels, axis=0)
    except Exception:
      self.write(formatted_traceback())
    finally:
      # Remove the temporary file even when a step above failed.
      if os.path.exists(cropped_file):
        os.remove(cropped_file)


  def classifier_button_activated(self, text):
    """Run classify() with both buttons disabled, logging any error."""
    self.classifier_button.setEnabled(False)
    self.clear_button.setEnabled(False)
    try:
      self.classify()
    except Exception:
      self.write(formatted_traceback())
    finally:
      self.classifier_button.setEnabled(True)
      self.clear_button.setEnabled(True)


  def classify(self):
    """Predict the class of self.image and write the result to the log."""
    self.write("classify start")

    prediction = self.model.predict(self.image)

    # Assumes predict() returns (batch, num_classes) scores, as the axis=1
    # argmax implies; self.image holds a single sample.
    index = int(np.argmax(prediction, axis=1)[0])
    self.write("Prediction: index " + str(index))

    class_names = self.class_names_set[self.dataset_id]
    # Guard against a model whose output size does not match class_names.
    if 0 <= index < len(class_names):
      self.write("Prediction:" + class_names[index])
    else:
      self.write("Prediction: index out of range for "
                 + str(len(class_names)) + " classes")

    self.write("classify end")


  def clear_button_activated(self, text):
    """Clear the log text editor."""
    self.text_editor.setText("")
  
############################################################
#    
if __name__ == "__main__":
  # Script entry point: create the Qt application and the main window, then
  # run the event loop.  Any startup error is printed instead of crashing
  # silently; KeyboardInterrupt/SystemExit are allowed to propagate.
  try:
    app_name  = os.path.basename(sys.argv[0])
    applet    = QApplication(sys.argv)

    main_view = MainView(app_name, 40, 40, 900, 500)
    main_view.show()

    applet.exec_()

  except Exception:
    traceback.print_exc()
    

Last modified: 22 Sep. 2018