SOL4Py Sample: TorchCIFARClassifier

SOL4Py Samples















#******************************************************************************
#
#  Copyright (c) 2018-2019 Antillia.com TOSHIYUKI ARAI. ALL RIGHTS RESERVED.
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#******************************************************************************

 
# 2019/06/30
# 2019/09/18

# On CIFAR-10 dataset, see the following page:

# http://www.cs.toronto.edu/~kriz/cifar.html

# See: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#loading-and-normalizing-cifar10
#
#  TorchCIFARClassifier.py

# encodig: utf-8

import sys
import os
import time
import traceback
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
import numpy as np


sys.path.append('../../')

from SOL4Py.ZTorchImageClassifierView import *
from SOL4Py.torch.ZTorchImagePreprocessor import ZTorchImagePreprocessor

from TorchCIFARModel import *

CIFAR10  = 0
CIFAR100 = 1

############################################################
# Classifier View

class MainView(ZTorchImageClassifierView):  
  # Class variables

  # ClassifierView Constructor
  def __init__(self, title, x, y, width, height):
    super(MainView, self).__init__(title, x, y, width, height,
                        datasets = {"CIFAR10": CIFAR10, "CIFAR100": CIFAR100})
                            
    self.model_loaded = False
    
    self.class_names_set = [None, None]
    
    self.resize = 32
    self.crop   = 32
    self.image       = None
                        
    # The names of the classes.
    # See https://github.com/ageron/tensorflow-models/blob/master/slim/datasets/download_and_convert_cifar.py
    
    self.class_names_set[CIFAR10] = [ 
                'airplane','automobile', 'bird',     'cat',    'deer',
                'dog',     'frog',       'horse',    'ship',   'truck',]

    self.class_names_set[CIFAR100] = [
                'apple',      'aquarium_fish', 'baby',       'bear',         'beaver',
                'bed',        'bee',           'beetle',     'bicycle',      'bottle',
                'bowl',       'boy',           'bridge',     'bus',          'butterfly',
                'camel',      'can',           'castle',     'caterpillar',  'cattle',
                'chair',      'chimpanzee',    'clock',      'cloud',        'cockroach',
                'couch',      'crab',          'crocodile',  'cup',          'dinosaur',
                'dolphin',    'elephant',      'flatfish',   'forest',       'fox',
                'girl',       'hamster',       'house',      'kangaroo',     'keyboard',
                'lamp',       'lawn_mower',    'leopard',    'lion',         'lizard',
                'lobster',    'man',           'maple_tree', 'motorcycle',   'mountain',
                'mouse',      'mushroom',      'oak_tree',   'orange',       'orchid',
                'otter',      'palm_tree',     'pear',       'pickup_truck', 'pine_tree',
                'plain',      'plate',         'poppy',      'porcupine',    'possum',
                'rabbit',     'raccoon',       'ray',        'road',         'rocket',
                'rose',       'sea',           'seal',       'shark',        'shrew',
                'skunk',      'skyscraper',    'snail',      'snake',        'spider',
                'squirrel',   'streetcar',     'sunflower',  'sweet_pepper', 'table',
                'tank',       'telephone',     'television', 'tiger',        'tractor',
                'train',      'trout',         'tulip',      'turtle',       'wardrobe',
                'whale',      'willow_tree',   'wolf',       'woman',        'worm',]

    
    #  Load trained model
    
    self.model = TorchCIFARModel(self.dataset_id, mainv=self)
    
    if self.model.is_trained():
      self.model.load_dataset()
      self.model.create()
      self.model.load()    # Load a trained weight
      self.model.evaluate()
      self.model_loaded = True
    else:
      print("You have to create a model file")
      print("Please run: python CIFARModel.py " + str(self.dataset_id))
      QMessageBox.warning(self, "MNIST", 
           "Mode file is missing.\nPlease run: python CIFARModel.py " + str(self.dataset_id))

    self.show()



  def datasets_activated(self, text):
    self.dataset_id = self.datasets[text]
    title = self.get_title()
    self.setWindowTitle(text + " - " + title)

    self.classifier_button.setEnabled(False)
    
    self.model.set_dataset_id(self.dataset_id)
    self.model.load_dataset()
    
    if self.model.is_trained():
    
      self.model.load()
      #self.model.evaluate()
      self.model_loaded = True
      
    else:
      self.model.build()
      
      print("You have to create a model file and weight file")
      print("Run: python TorchCIFARModel.py " + str(self.dataset_id))
      QMessageBox.warning(self, "CIFAR", 
           "Model/Weight File Missing.\nPlease run: python TorchCIFARModel.py " + str(self.dataset_id))
 

  def classify(self):
    self.write("--------------------------------------------")
    self.write("classify start.")
    self.write(self.filename)
    input = Variable(self.image_tensor)
    index = self.model.predict(input)
    classes = self.class_names_set[self.dataset_id]
    label   = classes[index]
    self.write("Prediction: {}".format(label) )

    self.write("classify end.")


############################################################
#    
if main(__name__):

  try:
    app_name  = os.path.basename(sys.argv[0])
    applet    = QApplication(sys.argv)
  
    main_view = MainView(app_name, 40, 40, 900, 500)
    main_view.show ()

    applet.exec_()

  except:
    traceback.print_exc()
    

Last modified:20 Sep. 2019