SOL4Py Sample: RandomForestClassifier

SOL4Py Samples





#******************************************************************************
#
#  Copyright (c) 2018 Antillia.com TOSHIYUKI ARAI. ALL RIGHTS RESERVED.
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
#******************************************************************************

 
# 2018/09/01

#  RandomForestClassifier.py
# This is base on the following sample program:
# http://scikit-learn.org/stable/auto_examples/ensemble/plot_feature_transformation.html
#
# Author: Tim Head <betatim@gmail.com>
#
# License: BSD 3 clause

# See also:

# http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html
# https://github.com/scikit-learn/scikit-learn/blob/f0ab589f/sklearn/ensemble/forest.py#L1019

#
# encodig: utf-8

import sys
import os
import cv2
import time
import traceback
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
import numpy as np

import pickle
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
from sklearn.metrics import confusion_matrix, classification_report

from sklearn.model_selection import train_test_split

from sklearn import datasets

from PyQt5.QtCore    import *
from PyQt5.QtWidgets import *
from PyQt5.QtGui     import *
 
sys.path.append('../')

from SOL4Py.ZMLModel         import *
from SOL4Py.ZApplicationView import *
from SOL4Py.ZLabeledComboBox import *
from SOL4Py.ZPushButton      import *
from SOL4Py.ZVerticalPane    import * 
from SOL4Py.ZTabbedWindow    import *
from SOL4Py.ZScalableScrolledFigureView import *

Iris         = 0
Digits       = 1
Wine         = 2
BreastCancer = 3

############################################################
# Classifier Model clas

class RandomForestClassifierModel(ZMLModel):

  ##
  # Constructor
  def __init__(self, dataset_id, mainv):
    super(RandomForestClassifierModel, self).__init__(dataset_id, mainv)

    self.roc_label_position = 1
    
  def run(self):
    self.write("====================================")
    self._start(self.run.__name__)
    try:
      self.load_dataset()
      if self.trained():
        self.load()
      else:
        self.build()
        self.train()
        self.save()
        
      self.predict()
      self.visualize() 
      
    except:
      traceback.print_exc()
    self._end(self.run.__name__)
     
  def load_dataset(self):
    self._start(self.load_dataset.__name__)
    
    if self.dataset_id == Iris:
       self.dataset= datasets.load_iris()
       self.write("loaded iris dataset")

    if self.dataset_id == Digits:
       self.dataset= datasets.load_digits()
       self.write("loaded Digits dataset")
  
    if self.dataset_id == Wine:
       self.dataset= datasets.load_wine()
       self.write("loaded Wine dataset")

    if self.dataset_id == BreastCancer:
       self.dataset= datasets.load_breast_cancer()
       self.write("loaded BreastCancer dataset")
       
    attr = dir(self.dataset)
    self.write("dir:" + str(attr))
    if "feature_names" in attr:
      self.write("feature_names:" + str(self.dataset.feature_names))
    if "target_names" in attr:
      self.write("target_names:" + str(self.dataset.target_names))
 
    self.set_model_filename()
 
    self.view.description.setText(self.dataset.DESCR)
    
    X, y = self.dataset.data, self.dataset.target
    
    self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y, test_size=0.3)

    self.X_train, self.X_train_lr, self.y_train, self.y_train_lr = train_test_split(self.X_train, self.y_train, test_size=0.3)
    self._end(self.load_dataset.__name__)
 
  def build(self):
    self._start(self.build.__name__)
    self.model = RandomForestClassifier(max_depth=10, n_estimators= 10)
    self.enc   = OneHotEncoder()
    self.lr    = LogisticRegression()

    self._end(self.run.__name__)

  def train(self):  
    self._start(self.train.__name__)
    start = time.time()
    # Class fit method of the classifier
    self.model.fit(self.X_train, self.y_train)
    
    self.enc.fit(self.model.apply(self.X_train))
    self.lr.fit(self.enc.transform(self.model.apply(self.X_train_lr)), self.y_train_lr)

    elapsed_time = time.time() - start
    elapsed = str("Train elapsed_time:{0}".format(elapsed_time) + "[sec]")
    self.write(elapsed)
    self._end(self.train.__name__)


  def predict(self):
    self._start(self.predict.__name__)
    
    # RandomForestClassifer.predict
    self.pred_test  = self.model.predict(self.X_test)
    
    # Call LogisticRegression.predic_prob
    self.y_pred = self.lr.predict_proba(self.enc.transform(self.model.apply(self.X_test)))[:, 1]
    self._end(self.predict.__name__)
    

  def visualize(self):
   cmatrix = confusion_matrix(self.y_test, self.pred_test)
   false_positive_rate, true_positive_rate, _ = roc_curve(self.y_test, self.y_pred, pos_label=self.roc_label_position)

   self.view.visualize(cmatrix, false_positive_rate, true_positive_rate, self.roc_label_position)
   
   
 
############################################################
# Classifier View

class MainView(ZApplicationView):  
  # Class variables

  # ClassifierView Constructor
  def __init__(self, title, x, y, width, height):
    super(MainView, self).__init__(title, x, y, width, height)
    self.font        = QFont("Arial", 10)
    self.setFont(self.font)
    
    # 1 Add a labeled combobox to top dock area
    self.add_datasets_combobox()
    
    # 2 Create a textedit to the left pane of the center area.
    self.text_editor = QTextEdit()
    self.text_editor.setLineWrapColumnOrWidth(600)
    self.text_editor.setLineWrapMode(QTextEdit.FixedPixelWidth)

    # 3 Create a description to display dataset.DESCR.
    self.description = QTextEdit()
    self.description.setLineWrapColumnOrWidth(600)
    self.description.setLineWrapMode(QTextEdit.FixedPixelWidth)
    
    # 4 Create a tabbed_window
    self.tabbed_window = ZTabbedWindow(self, 0, 0, width/2, height)

    # 5 Create a figure_view to the right pane of the center area.
    self.figure_view = ZScalableScrolledFigureView(self, 0, 0, width/2, height) 

    # 6 Create a roc_curve to the right pane of the center area.
    self.roc_curve = ZScalableScrolledFigureView(self, 0, 0, width/2, height) 
      
    self.add(self.text_editor)
    self.add(self.tabbed_window)
    
    self.tabbed_window.add("Description",  self.description)
    self.tabbed_window.add("ConfusionMatrix", self.figure_view)
    self.tabbed_window.add("ROC Curve", self.roc_curve)
    
    self.figure_view.hide()
    self.roc_curve.hide()
    
    self.show()
    
  def add_datasets_combobox(self):
    self.dataset_id = Iris
    self.datasets_combobox = ZLabeledComboBox(self, "Datasets", Qt.Horizontal)
    
    # We use the following datasets of sklearn to test RandomForestClassifier.
    self.datasets = {"Iris": Iris, "Digits": Digits, "Wine": Wine, "BreastCancer": BreastCancer}
    title = self.get_title()
    self.setWindowTitle( "Iris" + " - " + title)
    
    self.datasets_combobox.add_items(self.datasets.keys())
    self.datasets_combobox.add_activated_callback(self.datasets_activated)
    self.datasets_combobox.set_current_text(self.dataset_id)

    self.start_button = ZPushButton("Start", self)
    self.clear_button = ZPushButton("Clear", self)
    
    self.start_button.add_activated_callback(self.start_button_activated)
    self.clear_button.add_activated_callback(self.clear_button_activated)

    self.datasets_combobox.add(self.start_button)
    self.datasets_combobox.add(self.clear_button)
    
    self.set_top_dock(self.datasets_combobox)
  

  def write(self, text):
    self.text_editor.append(text)
    self.text_editor.repaint()
    
  def datasets_activated(self, text):
    self.dataset_id = self.datasets[text]
    title = self.get_title()
    self.setWindowTitle(text + " - " + title)

  def start_button_activated(self, text):
    self.model = RandomForestClassifierModel(self.dataset_id, self)
    self.start_button.setEnabled(False)    
    self.clear_button.setEnabled(False)
    try:
      self.model.run()
    except:
      pass
    self.start_button.setEnabled(True)
    self.clear_button.setEnabled(True)
 
    
  def clear_button_activated(self, text):
    self.text_editor.setText("")
    self.description.setText("")
    
    self.figure_view.hide()
    self.roc_curve.hide()
    
    plt.close()
    
  def visualize(self, cmatrix, false_positive_rate, true_positive_rate, label_position):
    self.figure_view.show()
    plt.close()

    sns.set()
    df = pd.DataFrame(cmatrix)
    sns.heatmap(df, annot=True, fmt="d")
    # Set a new figure to the figure_view.
    self.figure_view.set_figure(plt)
    plt.close()
    
    self.roc_curve.show()
    
    self.write("AUC :" + str( auc(false_positive_rate, true_positive_rate) ))
    
    plt.plot([0, 1], [0, 1], 'k--')
    plt.plot(false_positive_rate, true_positive_rate, label='RT + LR')

    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title('ROC curve: label_position=' + str(label_position))
    plt.legend(loc='best')
    
    self.roc_curve.set_figure(plt)
    plt.close()

  
############################################################
#    
if main(__name__):

  try:
    app_name  = os.path.basename(sys.argv[0])
    applet    = QApplication(sys.argv)
  
    main_view = MainView(app_name, 40, 40, 800, 500)
    main_view.show ()

    applet.exec_()

  except:
    traceback.print_exc()
    

Last modified: 6 May 2018