# SOL4Py Sample: MNISTClassifier

# SOL4Py Samples











#******************************************************************************
#
#  Copyright (c) 2018 Antillia.com TOSHIYUKI ARAI. ALL RIGHTS RESERVED.
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#******************************************************************************

 
# 2018/09/20


#  MNISTClassifier.py

# encoding: utf-8

import sys
import os
import time
import traceback
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

#from keras.models import Sequential
#from keras.layers.convolutional import Conv2D
#from keras.layers.pooling import MaxPool2D
#from keras.layers.core import Dense,Activation,Dropout,Flatten
#from keras.datasets import cifar10

from keras.utils import np_utils
from keras.preprocessing.image import load_img, img_to_array


from PyQt5.QtCore    import *
from PyQt5.QtWidgets import *
from PyQt5.QtGui     import *
 
sys.path.append('../../')

from SOL4Py.ZApplicationView import *
from SOL4Py.ZLabeledComboBox import *
from SOL4Py.ZPushButton      import *
from SOL4Py.ZVerticalPane    import *
from SOL4Py.ZPILImageCropper import *
 
from SOL4Py.ZScrolledPlottingArea import *
from SOL4Py.ZScalableScrolledImageView import *
from SOL4Py.ZTabbedWindow import *

from MNISTModel import *

MNIST         = 0
FASHION_MNIST = 1


############################################################
# Classifier View

class MainView(ZApplicationView):
  """Main window of the MNIST / Fashion-MNIST image classifier sample.

  Layout:
    - top dock: a "Datasets" combobox plus "Classify" and "Clear" buttons,
    - left pane: a text editor used as an output/log console,
    - right pane: a tabbed window with the cropped source image
      ("SourceImage") and the 28x28 grayscale image fed to the model
      ("TestImage").

  A trained model for the selected dataset is loaded through MNISTModel; when
  MNISTModel.trained() is False the user is told to run MNISTModel.py first.
  """

  def __init__(self, title, x, y, width, height):
    """Build the UI and try to load a trained model for the default dataset.

    Args:
      title:         window title.
      x, y:          window position.
      width, height: window size; also used to size the child panes.
    """
    super(MainView, self).__init__(title, x, y, width, height)
    self.font        = QFont("Arial", 10)
    self.setFont(self.font)

    # Class names indexed by dataset id (MNIST / FASHION_MNIST).
    self.class_names_set = [None, None]
    self.class_names_set[MNIST] = ["0", "1", "2", "3", "4",
                                   "5", "6", "7", "8", "9"]
    self.class_names_set[FASHION_MNIST] = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                                           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    # Preprocessed test image as an ndarray of shape (1, 28, 28, 1), or None
    # when no image has been loaded successfully.
    self.image        = None
    self.model_loaded = False

    # 1 Add a labeled combobox to the top dock area.
    self.add_datasets_combobox()

    # Qt geometry setters take ints, so use integer division (width/2 is a float).
    half_width   = width  // 2
    third_width  = width  // 3
    third_height = height // 3

    # 2 Add a textedit to the left pane of the center area.
    self.text_editor = QTextEdit()
    self.text_editor.setLineWrapColumnOrWidth(600)
    self.text_editor.setLineWrapMode(QTextEdit.FixedPixelWidth)
    self.text_editor.setGeometry(0, 0, half_width, height)

    # 3 Add a tabbed_window to the right pane of the center area.
    self.tabbed_window = ZTabbedWindow(self, 0, 0, half_width, height)

    # 4 Add an image_view (cropped source image) to the tabbed_window.
    self.image_view = ZScalableScrolledImageView(self, 0, 0, third_width, third_height)
    self.tabbed_window.add("SourceImage", self.image_view)

    # 5 Add a test_image_view (28x28 model input) to the tabbed_window.
    self.test_image_view = ZScalableScrolledImageView(self, 0, 0, third_width, third_height)
    self.tabbed_window.add("TestImage", self.test_image_view)

    self.add(self.text_editor)
    self.add(self.tabbed_window)

    # 6 Load the trained model for the default dataset.
    self.model = MNISTModel(self.dataset_id, epochs = 10, mainv=self)
    self.model_loaded = self._load_trained_model()
    if not self.model_loaded:
      print("You have to create a model file and weight file")
      print("Run: python MNISTModel.py " + str(self.dataset_id))
      QMessageBox.warning(self, "MNIST",
           "Model/Weight File Missing.\nPlease run: python MNISTModel.py " + str(self.dataset_id))

    self.show()


  def _load_trained_model(self):
    """Load, compile and evaluate the model for the current dataset id.

    Returns:
      True when self.model.trained() is True and the load/compile/evaluate
      sequence completed; False when trained() is False (reported to the user
      as missing model/weight files). Exceptions from the model propagate.
    """
    if not self.model.trained():
      return False
    self.model.load_dataset()
    self.model.load()
    self.model.compile()
    self.model.evaluate()
    return True


  def add_datasets_combobox(self):
    """Create the top-dock combobox (dataset choice) and the Classify/Clear buttons."""
    self.dataset_id = MNIST
    self.datasets_combobox = ZLabeledComboBox(self, "Datasets", Qt.Horizontal)
    self.datasets_combobox.setFont(self.font)

    # Combobox text -> dataset id (index into self.class_names_set).
    self.datasets = {"MNIST": MNIST, "FashionMNIST": FASHION_MNIST}
    title = self.get_title()

    self.setWindowTitle(self.__class__.__name__ + " - " + title)

    self.datasets_combobox.add_items(self.datasets.keys())
    self.datasets_combobox.add_activated_callback(self.datasets_activated)
    # NOTE(review): passes the dataset id (an int), not its name; confirm
    # ZLabeledComboBox.set_current_text accepts this.
    self.datasets_combobox.set_current_text(self.dataset_id)

    # Classification stays disabled until an image has been loaded.
    self.classifier_button = ZPushButton("Classify", self)
    self.classifier_button.setEnabled(False)

    self.clear_button = ZPushButton("Clear", self)

    self.classifier_button.add_activated_callback(self.classifier_button_activated)
    self.clear_button.add_activated_callback(self.clear_button_activated)

    self.datasets_combobox.add(self.classifier_button)
    self.datasets_combobox.add(self.clear_button)

    self.set_top_dock(self.datasets_combobox)


  def write(self, text):
    """Append `text` to the output console and repaint it immediately."""
    self.text_editor.append(text)
    self.text_editor.repaint()


  def datasets_activated(self, text):
    """Switch to the dataset named `text` and (re)load its trained model."""
    self.dataset_id = self.datasets[text]
    title = self.get_title()
    self.setWindowTitle(text + " - " + title)
    self.model.set_dataset_id(self.dataset_id)

    # Until the new dataset's model is loaded, nothing may be classified. Clear
    # the flag too, so file_open() never proceeds with the previous dataset's model.
    self.classifier_button.setEnabled(False)
    self.model_loaded = False

    if self._load_trained_model():
      self.model_loaded = True
      QMessageBox.information(self, "MNIST",
           "OK: MNIST Model/Weight files loaded")
    else:
      print("You have to create a model file and a weight file.")
      print("Please run: python MNISTModel.py " + str(self.dataset_id))
      QMessageBox.warning(self, "MNIST",
           "Model/Weight files missing.\nPlease run: python MNISTModel.py " + str(self.dataset_id))


  def file_open(self):
    """Show a FileOpenDialog, load the chosen image and enable classification."""
    if not self.model_loaded:
      QMessageBox.warning(self, "MNIST: Weight File Missing",
           "Please run: python MNISTModel.py " + str(self.dataset_id))
      return

    options = QFileDialog.Options()
    # Qt name filters separate patterns with spaces.
    filename, _ = QFileDialog.getOpenFileName(self, "FileOpenDialog", "",
                   "All Files (*);;Image Files (*.png *.jpg *.jpeg)", options=options)
    if filename:
      self.load_file(filename)
    # self.image is None until an image has been loaded successfully (the
    # dialog may also have been cancelled), so only enable Classify when set.
    self.classifier_button.setEnabled(self.image is not None)


  def load_file(self, filename):
    """Load `filename`, show it, and store a model-ready copy in self.image.

    Steps: crop the largest square region to a temporary PNG, show it in the
    "SourceImage" tab, reload it as a 28x28 grayscale image, show that in the
    "TestImage" tab, scale pixels to [0, 1] and add a leading batch axis,
    giving shape (1, 28, 28, 1).
    On any failure the traceback is written to the console and self.image is
    reset to None. The temporary file is always removed.
    """
    cropped_file = "./~temp_cropped.png"
    try:
      image_cropper = ZPILImageCropper(filename)
      # 1 Crop the maximum square region from filename and save it as cropped_file.
      image_cropper.crop_maximum_square_region(cropped_file)

      # 2 Show the cropped image.
      self.image_view.load_image(cropped_file)
      self.set_filenamed_title(filename)

      # 3 Reload cropped_file as a 28x28 grayscale Pillow image.
      image = load_img(cropped_file, grayscale=True,
                       target_size=(28, 28))

      # 4 Convert the Pillow image to a numpy ndarray.
      image = img_to_array(image)

      # 5 Show it in the test_image_view.
      self.test_image_view.set_image(image)

      image = image.reshape(28, 28, 1)

      # 6 Scale pixel values into the range [0, 1.0].
      image = image.astype('float32') / 255.0

      # 7 Add a leading batch axis.
      image = np.expand_dims(image, axis=0)

      # Publish the image only once every step has succeeded.
      self.image = image
      self.write("Test image:{}".format(self.image.shape))

    except Exception:
      self.image = None
      self.write(formatted_traceback())

    finally:
      # The crop may have failed before the file was written.
      try:
        os.remove(cropped_file)
      except OSError:
        pass


  def classifier_button_activated(self, text):
    """Classify the loaded image, disabling both buttons while it runs."""
    self.classifier_button.setEnabled(False)
    self.clear_button.setEnabled(False)
    try:
      self.classify()
    except Exception:
      self.write(formatted_traceback())

    self.classifier_button.setEnabled(True)
    self.clear_button.setEnabled(True)


  def classify(self):
    """Run the model on self.image and write the predicted class name to the console."""
    self.write("classify start")
    try:
      prediction = self.model.predict(self.image)
      self.write("Prediction: " + str(prediction))

      # argmax over axis 1 expects a 2-D result (one row per input image);
      # only a single image is passed, so the first row's index is used.
      pred = np.argmax(prediction, axis=1)
      self.write("Prediction: index " + str(pred))

      index = int(pred[0])
      class_names = self.class_names_set[self.dataset_id]
      # Guard against an index outside the class-name table.
      if 0 <= index < len(class_names):
        self.write("Prediction:" + class_names[index])

    except Exception:
      self.write(formatted_traceback())

    self.write("classify end")


  def clear_button_activated(self, text):
    """Clear the output console."""
    self.text_editor.setText("")
  

############################################################
#    
if __name__ == "__main__":
  # Script entry point: create the Qt application and the main window, then
  # run the event loop until the window is closed.
  try:
    app_name  = os.path.basename(sys.argv[0])
    applet    = QApplication(sys.argv)

    main_view = MainView(app_name, 40, 40, 900, 500)
    main_view.show()

    # Propagate the event loop's exit status to the shell.
    sys.exit(applet.exec_())

  except Exception:
    traceback.print_exc()
    

# Last modified: 22 Sep. 2018