Implementation and Visualization of Decision Tree Classification and Regression Models

Overview

DecisionTree

Decision tree classification and regression models, with visualization

ID3

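The ID3 decision tree is the most basic decision tree classifier:

  • No pruning
  • Supports discrete attributes only
  • Uses the information gain criterion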
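
To make the information gain criterion concrete, here is a minimal standard-library sketch. It is not the repository's `ID3Classifier` code; the function names and the toy data are made up for illustration.

```python
from collections import Counter
from math import log2


def entropy(labels):
    """Shannon entropy (in bits) of a list of class labels."""
    total = len(labels)
    return -sum((n / total) * log2(n / total) for n in Counter(labels).values())


def information_gain(values, labels):
    """Entropy reduction obtained by splitting on one discrete attribute."""
    total = len(labels)
    groups = {}
    for v, y in zip(values, labels):
        groups.setdefault(v, []).append(y)
    remainder = sum(len(g) / total * entropy(g) for g in groups.values())
    return entropy(labels) - remainder


# Toy data, made up for illustration (not the watermelon dataset).
texture = ["clear", "clear", "blurry", "blurry"]
ripe = [1, 1, 0, 0]
print(information_gain(texture, ripe))  # 1.0: the attribute separates the classes perfectly
```

At each node, ID3 picks the attribute with the largest gain and creates one child per attribute value.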

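data.py contains a small watermelon dataset for a binary classification task with discrete attributes. An ID3 decision tree classifier can be trained on it as follows: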

from ID3 import ID3Classifier
from data import load_watermelon2
import numpy as np

X, y = load_watermelon2(return_X_y=True)  # the parameter mirrors sklearn.datasets
model = ID3Classifier()
model.fit(X, y)
pred = model.predict(X)
print(np.mean(pred == y))

The output is 1.0: the tree classifies every training sample correctly, which indicates that it was built correctly.

C4.5

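The C4.5 decision tree classifier improves on ID3 in several ways:

  • It uses the gain ratio heuristic to choose the splitting feature.
  • It handles both discrete and continuous attributes by discretizing the continuous ones.
  • It supports pruning.
  • It can handle training data with missing attribute values.

We implemented the first two points, plus the pre-pruning part of the third (controlled by hyperparameters).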
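
The first two points can be sketched in a few lines of standard-library Python. This is a simplified variant that scores candidate thresholds by gain ratio directly; the repository's `C4_5Classifier` and Quinlan's original algorithm may differ in detail. All names and values below are hypothetical.

```python
from collections import Counter
from math import log2


def entropy(items):
    """Shannon entropy (in bits) of a list of hashable items."""
    total = len(items)
    return -sum((n / total) * log2(n / total) for n in Counter(items).values())


def gain_ratio(partition_keys, labels):
    """Information gain divided by the split information of the partition."""
    total = len(labels)
    groups = {}
    for k, y in zip(partition_keys, labels):
        groups.setdefault(k, []).append(y)
    gain = entropy(labels) - sum(len(g) / total * entropy(g) for g in groups.values())
    split_info = entropy(partition_keys)  # entropy of the partition itself
    return gain / split_info if split_info > 0 else 0.0


def best_threshold(values, labels):
    """Try midpoints between consecutive distinct values of a continuous attribute."""
    distinct = sorted(set(values))
    best = (None, 0.0)
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2
        ratio = gain_ratio([v <= t for v in values], labels)
        if ratio > best[1]:
            best = (t, ratio)
    return best


density = [0.25, 0.5, 1.0, 1.5]  # made-up continuous feature
ripe = [0, 0, 1, 1]
print(best_threshold(density, ripe))  # (0.75, 1.0)
```

Dividing by the split information penalizes attributes that fragment the data into many small groups, which plain information gain tends to favor.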

data.py also contains a watermelon dataset that mixes continuous and discrete features; we use it to test the C4.5 decision tree:

from C4_5 import C4_5Classifier
from data import load_watermelon3
import numpy as np

X, y = load_watermelon3(return_X_y=True)  # the parameter mirrors sklearn.datasets
model = C4_5Classifier()
model.fit(X, y)
pred = model.predict(X)
print(np.mean(pred == y))

The output is 1.0, which indicates that the generated decision tree is correct.

CART

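Classification

CART (Classification and Regression Tree) is closely related to C4.5 and, unlike it, supports both classification and regression. The CART classification tree uses the Gini index to choose features. In addition, for discrete features CART makes a binary split at each node, which mitigates overfitting.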
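
As an illustration of Gini-based binary splitting on a discrete feature, the sketch below searches value subsets exhaustively. This is an assumption about one possible strategy, not necessarily what `CARTClassifier` does (a one-value-versus-rest split is another common choice); the names and data are hypothetical.

```python
from collections import Counter
from itertools import combinations


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    total = len(labels)
    return 1.0 - sum((n / total) ** 2 for n in Counter(labels).values())


def best_binary_split(values, labels):
    """Find the value subset S minimizing weighted Gini for `v in S` vs the rest."""
    categories = sorted(set(values))
    total = len(labels)
    best = (None, float("inf"))
    # Subsets of up to half the categories cover every two-way partition.
    for size in range(1, len(categories) // 2 + 1):
        for subset in combinations(categories, size):
            left = [y for v, y in zip(values, labels) if v in subset]
            right = [y for v, y in zip(values, labels) if v not in subset]
            score = len(left) / total * gini(left) + len(right) / total * gini(right)
            if score < best[1]:
                best = (set(subset), score)
    return best


color = ["green", "green", "black", "white"]  # made-up discrete feature
ripe = [1, 1, 0, 0]
print(best_binary_split(color, ripe))  # ({'green'}, 0.0)
```

Because every node has exactly two children, a feature with many values no longer shatters the data into many tiny branches in a single step.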

Here we test it on the iris dataset from sklearn:

from CART import CARTClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

X, y = load_iris(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model = CARTClassifier()
model.fit(train_X, train_y)
pred = model.predict(test_X)
print(accuracy_score(test_y, pred))

Accuracy: 95.55% (the train/test split is random, so the exact figure varies between runs).

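Regression

The CARTRegressor class implements decision tree regression. The example below uses sklearn's Boston housing dataset (load_boston was removed in scikit-learn 1.2, so it requires an earlier version):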

from CART import CARTRegressor
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

X, y = load_boston(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model = CARTRegressor()
model.fit(train_X, train_y)
pred = model.predict(test_X)
print(mean_squared_error(test_y, pred))

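The output (mean squared error) is 26.352171052631576. The baseline from sklearn's decision tree regressor is 22.46; the two are comparable, which suggests that our implementation is correct.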
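
For intuition about how a regression tree picks a split, here is a minimal sketch that minimizes the summed squared error of the two children on a single feature. It assumes the usual CART squared-error criterion and is not taken from `CARTRegressor`; the names and numbers are made up.

```python
def sse(ys):
    """Sum of squared deviations from the mean (0 for an empty list)."""
    if not ys:
        return 0.0
    mean = sum(ys) / len(ys)
    return sum((y - mean) ** 2 for y in ys)


def best_regression_split(xs, ys):
    """Return (threshold, total SSE) of the best `x <= threshold` split."""
    distinct = sorted(set(xs))
    best = (None, sse(ys))  # baseline: no split at all
    for lo, hi in zip(distinct, distinct[1:]):
        t = (lo + hi) / 2
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        total = sse(left) + sse(right)
        if total < best[1]:
            best = (t, total)
    return best


rooms = [4.0, 5.0, 6.0, 7.0]      # made-up feature values
price = [10.0, 12.0, 30.0, 32.0]  # made-up targets
print(best_regression_split(rooms, price))  # (5.5, 4.0)
```

Each leaf then predicts the mean target of the training samples that reach it.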

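Plotting decision trees

Classification trees

Using the third-party graphviz package for Python 3 together with Graphviz itself (which must be installed separately), we can visualize a decision tree: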

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)
model = CARTClassifier()
model.fit(X, y)
tree_plot(model)

Running this generates tree.png in the current folder:

[Image: iris_tree]
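
To show what a Graphviz-based plotter has to produce, the sketch below turns a tree into DOT source text using only the standard library. The nested-dict tree layout and the `to_dot` helper are hypothetical; the repository's `tree_plot` works on its own model classes and may be structured quite differently.

```python
def to_dot(tree):
    """Render a nested-dict tree as Graphviz DOT source text."""
    lines = ["digraph Tree {", "  node [shape=box];"]
    counter = 0

    def visit(node):
        nonlocal counter
        node_id = counter
        counter += 1
        lines.append(f'  n{node_id} [label="{node["label"]}"];')
        for edge_label, child in node.get("children", []):
            child_id = visit(child)
            lines.append(f'  n{node_id} -> n{child_id} [label="{edge_label}"];')
        return node_id

    visit(tree)
    lines.append("}")
    return "\n".join(lines)


# Hypothetical hand-written tree, not produced by the repository's classes.
toy = {
    "label": "petal length <= 2.45",
    "children": [
        ("yes", {"label": "setosa"}),
        ("no", {"label": "versicolor"}),
    ],
}
print(to_dot(toy))
```

Saving the printed text as tree.dot and running `dot -Tpng tree.dot -o tree.png` with Graphviz installed renders it to an image.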

If feature names and target names are provided, the tree is easier to read:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_iris

iris = load_iris()
model = CARTClassifier()
model.fit(iris.data, iris.target)
tree_plot(model,
          filename="tree2",
          feature_names=iris.feature_names,
          target_names=iris.target_names)

[Image: iris_tree2]

Plotting the ID3 decision tree for watermelon dataset 2:

from plot import tree_plot
from ID3 import ID3Classifier
from data import load_watermelon2

watermelon = load_watermelon2()
model = ID3Classifier()
model.fit(watermelon.data, watermelon.target)
tree_plot(
    model,
    filename="tree",
    font="SimHei",
    feature_names=watermelon.feature_names,
    target_names=watermelon.target_names,
)

A custom font must be specified here; otherwise the Chinese characters cannot be displayed:

[Image: watermelon]

Regression trees

Regression trees can be plotted in the same way:

from plot import tree_plot
from CART import CARTRegressor
from sklearn.datasets import load_boston

boston = load_boston()
model = CARTRegressor(max_depth=5)  # cap the depth so the plot stays readable
model.fit(boston.data, boston.target)
tree_plot(
    model,
    feature_names=boston.feature_names,
)

Because the full regression tree is very large, we cap the maximum depth before plotting:

[Image: regression]

Hyperparameter tuning

Both CART and C4.5 have hyperparameters. We make them subclasses of sklearn.base.BaseEstimator, so that sklearn's GridSearchCV can be used to tune them:

from plot import tree_plot
from CART import CARTClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split, GridSearchCV

wine = load_wine()
train_X, test_X, train_y, test_y = train_test_split(
    wine.data,
    wine.target,
    train_size=0.7,
)
model = CARTClassifier()
grid_param = {
    'max_depth': [2, 4, 6, 8, 10],
    'min_samples_leaf': [1, 3, 5, 7],
}

search = GridSearchCV(model, grid_param, n_jobs=4, verbose=5)
search.fit(train_X, train_y)
best_model = search.best_estimator_
print(search.best_params_, search.best_estimator_.score(test_X, test_y))
tree_plot(
    best_model,
    feature_names=wine.feature_names,
    target_names=wine.target_names,
)

The output shows the best parameters and the best model's score on the test set:

{'max_depth': 4, 'min_samples_leaf': 3} 0.8518518518518519

The corresponding decision tree:

[Image: wine]
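
The reason subclassing BaseEstimator is enough for grid search is the `get_params` / `set_params` protocol. The sketch below imitates that protocol with the standard library only, so it runs without sklearn; `ParamsMixin` and `ToyTree` are hypothetical stand-ins, and the real BaseEstimator does more (for example nested parameters), while GridSearchCV additionally clones and cross-validates each candidate.

```python
import inspect
from itertools import product


class ParamsMixin:
    """Simplified stand-in for the parameter handling of sklearn.base.BaseEstimator."""

    def get_params(self):
        # Parameter names are read from the constructor signature.
        names = [p for p in inspect.signature(type(self).__init__).parameters if p != "self"]
        return {name: getattr(self, name) for name in names}

    def set_params(self, **params):
        for name, value in params.items():
            setattr(self, name, value)
        return self


class ToyTree(ParamsMixin):
    """Hypothetical estimator; only the constructor matters here."""

    def __init__(self, max_depth=None, min_samples_leaf=1):
        # Each argument is stored under its own name, as the protocol expects.
        self.max_depth = max_depth
        self.min_samples_leaf = min_samples_leaf


grid = {"max_depth": [2, 4, 6, 8, 10], "min_samples_leaf": [1, 3, 5, 7]}
candidates = [
    ToyTree().set_params(**dict(zip(grid, combo)))
    for combo in product(*grid.values())
]
print(len(candidates), candidates[0].get_params())
# 20 {'max_depth': 2, 'min_samples_leaf': 1}
```

The grid above therefore yields 5 × 4 = 20 candidate settings, each of which the search fits and scores.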

Pruning

REP (reduced-error pruning) has been added to ID3 and to CART regression, while C4.5 supports PEP (pessimistic-error pruning).
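
As a rough illustration of the PEP decision at a single node, the sketch below applies the commonly cited textbook rule: add a 0.5 continuity correction per leaf and prune when the corrected leaf error is within one standard error of the corrected subtree error. This is an assumption about the general technique, and the repository's `pep_pruning` may be implemented differently; the function name and the counts are hypothetical.

```python
from math import sqrt


def pep_should_prune(n_samples, leaf_errors, subtree_errors, n_leaves):
    """Pessimistic-error test for replacing a subtree with a single leaf.

    All counts are measured on the training samples that reach the node.
    """
    corrected_leaf = leaf_errors + 0.5                   # one leaf after pruning
    corrected_subtree = subtree_errors + 0.5 * n_leaves  # 0.5 per existing leaf
    std_err = sqrt(corrected_subtree * (n_samples - corrected_subtree) / n_samples)
    return corrected_leaf <= corrected_subtree + std_err


# The subtree barely helps: prune it.
print(pep_should_prune(n_samples=20, leaf_errors=3, subtree_errors=2, n_leaves=3))    # True
# The subtree removes most errors: keep it.
print(pep_should_prune(n_samples=100, leaf_errors=40, subtree_errors=5, n_leaves=2))  # False
```

Unlike REP, this test needs no separate validation set, since it works from the training errors alone.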

Applying PEP pruning to a decision tree trained on the iris dataset:

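from C4_5 import C4_5Classifier
from plot import tree_plot
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()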
model = C4_5Classifier()
X, y = iris.data, iris.target
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7)
model.fit(train_X, train_y)
print(model.score(test_X, test_y))
tree_plot(model,
          filename="src/pre_prune",
          feature_names=iris.feature_names,
          target_names=iris.target_names)
model.pep_pruning()
print(model.score(test_X, test_y))
tree_plot(model,
          filename="src/post_prune",
          feature_names=iris.feature_names,
          target_names=iris.target_names,
)

The accuracy before and after pruning is 97.78% and 100% respectively, i.e. generalization performance improved:

[Image: prepre]

Owner
Welt Xing
Undergraduate in AI school, Nanjing University. Main interest(for now): Machine learning and deep learning.