SOL4Py Sample: TorchMNISTClassifier

SOL4Py Samples













#******************************************************************************
#
#  Copyright (c) 2018-2019 Antillia.com TOSHIYUKI ARAI. ALL RIGHTS RESERVED.
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#******************************************************************************

 
# 2019/09/10


#  TorchMNISTClassifier.py

# encoding: utf-8

import sys
import os
import time
import traceback

import numpy as np
from PIL import Image

sys.path.append('../../')

from SOL4Py.torch.ZTorchImagePreprocessor import ZTorchImagePreprocessor
from SOL4Py.ZTorchImageClassifierView import *

from TorchMNISTModel import TorchMNISTModel

# Dataset identifiers: used as indices into MainView.class_names_set and as
# the values of the dataset selector mapping handed to the base view.
MNIST, FASHION_MNIST = 0, 1


############################################################
# Classifier View

class MainView(ZTorchImageClassifierView):
  """Main window that classifies one image with a trained TorchMNISTModel.

  The user selects a dataset ("MNIST" or "FashionMNIST") and an image file.
  The image is converted to gray scale, cropped, and passed to
  ``TorchMNISTModel.predict``; the returned index selects a label from the
  class-name list of the current dataset, which is written to the log.
  """

  # ClassifierView Constructor
  def __init__(self, title, x, y, width, height):
    """Build the view and load the trained weights for the initial dataset.

    Args:
      title:         window title.
      x, y:          window position.
      width, height: window size.
    """
    super().__init__(title, x, y, width, height,
                     datasets = {"MNIST": MNIST, "FashionMNIST": FASHION_MNIST})

    # Size (pixels) used when resizing / cropping the input image.
    self.resize = 28
    self.crop   = 28

    # Gray-scale PIL image of the most recently opened file (None until then).
    self.image        = None
    self.model_loaded = False

    # Human-readable class names, indexed by dataset id.
    self.class_names_set = [None, None]
    self.class_names_set[MNIST] = ["0", "1", "2", "3", "4",
                                   "5", "6", "7", "8", "9"]
    self.class_names_set[FASHION_MNIST] = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                                           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

    # Load the trained model for the initial dataset, if its weights exist.
    self.model = TorchMNISTModel(self.dataset_id, epochs = 10, mainv=self)
    if self.model.is_trained():
      self.model.load_dataset()
      self.model.create()
      self.model.load()    # Load the trained weights
      self.model_loaded = True
    else:
      print("You have to create a model file.")
      print("Run: python TorchMNISTModel.py " + str(self.dataset_id))
      QMessageBox.warning(self, "MNIST",
           "Model/Weight File Missing.\nPlease run: python TorchMNISTModel.py " + str(self.dataset_id))

    self.show()


  def datasets_activated(self, text):
    """Switch to the dataset named *text* and reload its trained weights.

    Args:
      text: a key of the datasets mapping given to the base view
            ("MNIST" or "FashionMNIST").
    """
    self.dataset_id = self.datasets[text]
    title = self.get_title()
    self.setWindowTitle(text + " - " + title)
    self.model.set_dataset_id(self.dataset_id)

    self.classifier_button.setEnabled(False)
    # Weights loaded for the previous dataset are no longer valid; only mark
    # the model as loaded again if the new dataset's weights are found.
    self.model_loaded = False

    if self.model.is_trained():
      self.model.load_dataset()
      # NOTE(review): __init__ calls create() before load(), this path does
      # not. If the initial dataset had no weights, the model may never have
      # been created — confirm load() works in that case.
      self.model.load()
      self.model_loaded = True
      QMessageBox.information(self, "TorchMNIST",
           "OK: TorchMNIST Model files loaded")
    else:
      print("You have to create a model file and a weight file.")
      print("Please run: python TorchMNISTModel.py " + str(self.dataset_id))
      QMessageBox.warning(self, "MNIST",
           "Model/Weight files missing.\nPlease run: python TorchMNISTModel.py " + str(self.dataset_id))


  def _load_image(self, filename, resize, crop):
    """Open *filename* as a gray-scale image and prepare it for classification.

    Sets self.image, self.resized_image, self.cropped_image and self.ndarray,
    and shows the original and the cropped image in the two image views.
    Any error is written to the log via self.write instead of being raised.

    Args:
      filename: path of the image file to open.
      resize:   side length (pixels) for self.resized_image.
      crop:     crop size passed to ZTorchImagePreprocessor.image_crop.
    """
    self.ndarray = None
    # Drop the previous image first so a failed load cannot leave a stale
    # image behind for classify() to use.
    self.cropped_image = None
    try:
      # 1 Open the original image file and convert it to gray scale.
      #   convert() returns a new image, so the file can be closed right away.
      with Image.open(filename) as original:
        self.image = original.convert("L")

      self.image_view.set_image(np.array(self.image))
      self.resized_image = self.image.resize((resize, resize))

      preprocessor = ZTorchImagePreprocessor()

      self.set_filenamed_title(filename)

      # 2 Crop the image.
      self.cropped_image = preprocessor.image_crop(self.image, resize, crop)

      # 3 Convert the cropped image to a numpy ndarray.
      self.ndarray = np.array(self.cropped_image)

      # 4 Show the ndarray in the test_image_view.
      self.test_image_view.set_image(self.ndarray)

    except Exception:
      self.write(formatted_traceback())


  def load_filexx(self, filename):
    """Legacy variant of load_file that always uses a fixed 28x28 size."""
    self._load_image(filename, 28, 28)


  def load_file(self, filename):
    """Load *filename* using the view's configured resize/crop sizes."""
    self._load_image(filename, self.resize, self.crop)


  def classify(self):
    """Predict the class of the loaded image and write its label to the log."""
    self.write("--------------------------------------------")
    self.write("classify start")

    # getattr: cropped_image only exists once a file has been (attempted to be) loaded.
    cropped = getattr(self, "cropped_image", None)
    if cropped is None:
      self.write("No image loaded. Please open an image file first.")
      self.write("classify end")
      return
    if not self.model_loaded:
      self.write("No trained model loaded. Please run: python TorchMNISTModel.py "
                 + str(self.dataset_id))
      self.write("classify end")
      return

    self.write(self.filename)

    index   = self.model.predict(cropped)
    classes = self.class_names_set[self.dataset_id]
    label   = classes[index]
    self.write("Prediction: {}".format(label))

    self.write("classify end")


############################################################
#    
# Script entry point: create the Qt application, show the main view and run
# the event loop. The process exits with the event loop's return code, or 1
# if start-up raised an exception.
if __name__ == "__main__":
  exit_code = 1
  try:
    app_name  = os.path.basename(sys.argv[0])
    applet    = QApplication(sys.argv)

    main_view = MainView(app_name, 40, 40, 900, 500)
    main_view.show()

    exit_code = applet.exec_()

  except Exception:
    traceback.print_exc()

  sys.exit(exit_code)
    

Last modified: 20 Sep. 2019