Machine Learning 08: AdaBoosting, the Ensemble Classifier

Source: 日撸 Java 三百行 (Days 61-70, Decision Trees and Ensemble Learning), 闵帆's blog on CSDN

AdaBoosting

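Training procedure

1. Set the number of stump classifiers; here it is 100.

2. Loop over the stump classifiers, starting from the first. For the first stump, initialize the instance weights so that each instance has weight 1/n, where n is the number of instances. For every later stump, adjust the weight of each instance (see the weight adjustment in the first code listing of the AdaBoosting weighted data set post).

3. Train the current stump classifier (see the training step in the AdaBoosting stump classifier post).

4. Compute the sum of the weights of the misclassified instances (error). Use 0.5 * \log(\frac{1}{error} - 1) as the weight of the current stump classifier; the code sets this weight to 0 when the value falls below 1e-6.

5. If the training accuracy of the booster built so far exceeds 0.999999, consider the ensemble converged and exit the loop. Otherwise continue with the next stump.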
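Steps 2 and 4 can be illustrated with a small standalone sketch. The class name `AdaBoostWeightSketch` and both method names are hypothetical and are not part of the original code. The weight update shown is the standard AdaBoost rule (multiply by e^(-alpha) when correct, by e^(alpha) when wrong, then normalize); it is an assumption about what `WeightedInstances.adjustWeights` does, since that class is not listed in this post.

```java
import java.util.Arrays;

/** Hypothetical helper for illustration only; not part of the original Booster code. */
final class AdaBoostWeightSketch {

    /** Classifier weight alpha = 0.5 * ln(1 / error - 1), set to 0 when below 1e-6. */
    static double classifierWeight(double weightedError) {
        double alpha = 0.5 * Math.log(1 / weightedError - 1);
        // Same threshold as the listing: a stump no better than chance gets no vote.
        return alpha < 1e-6 ? 0 : alpha;
    }

    /** Standard AdaBoost instance-weight update followed by normalization (assumed rule). */
    static double[] adjustWeights(double[] weights, boolean[] correct, double alpha) {
        double[] result = new double[weights.length];
        double sum = 0;
        for (int i = 0; i < weights.length; i++) {
            // Shrink the weight of correctly classified instances, grow the others.
            result[i] = weights[i] * Math.exp(correct[i] ? -alpha : alpha);
            sum += result[i];
        }
        for (int i = 0; i < result.length; i++) {
            result[i] /= sum; // Keep the weights summing to 1.
        }
        return result;
    }

    public static void main(String[] args) {
        double[] weights = {0.25, 0.25, 0.25, 0.25}; // 1/n with n = 4
        boolean[] correct = {true, true, true, false};
        double error = 0.25; // Weight sum of the single misclassified instance.
        double alpha = classifierWeight(error);
        System.out.println("alpha = " + alpha);
        System.out.println(Arrays.toString(adjustWeights(weights, correct, alpha)));
    }
}
```

Under this rule the misclassified instances together carry half of the total weight after the update, which is what forces the next stump to focus on them. An error of 0.5 or more yields a weight of 0, so such a stump does not vote.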

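Classifying a single instance

1. For each trained stump classifier, get its predicted class for the instance (see the classification step in the AdaBoosting stump classifier post).

2. Add that stump's weight to the weight sum of the predicted class (tempLabelsCountArray).

3. Compare the weight sums of all classes; the instance is assigned to the class with the largest sum.

4. If the predicted class equals the instance's actual class, the classification is correct; otherwise it is wrong.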
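The weighted vote in steps 1 to 3 can be sketched without Weka. The class `WeightedVoteSketch` and its `vote` method are hypothetical names used only for illustration; the predictions are passed in as plain label indices instead of coming from real stump classifiers.

```java
/** Hypothetical illustration of the weighted vote; not part of the original Booster code. */
final class WeightedVoteSketch {

    /**
     * @param predictedLabels  the label index predicted by each base classifier
     * @param classifierWeights the weight of each base classifier
     * @param numClasses        the number of classes
     * @return the label with the largest weight sum
     */
    static int vote(int[] predictedLabels, double[] classifierWeights, int numClasses) {
        double[] labelWeightSums = new double[numClasses];
        for (int i = 0; i < predictedLabels.length; i++) {
            // Each classifier adds its own weight to the class it predicts.
            labelWeightSums[predictedLabels[i]] += classifierWeights[i];
        }

        int resultLabel = -1;
        double max = -1;
        for (int c = 0; c < numClasses; c++) {
            if (max < labelWeightSums[c]) {
                max = labelWeightSums[c];
                resultLabel = c;
            }
        }
        return resultLabel;
    }

    public static void main(String[] args) {
        // One strong stump votes for class 0, two weaker stumps vote for class 1.
        int[] predictions = {0, 1, 1};
        double[] weights = {0.9, 0.3, 0.4};
        System.out.println("Predicted class index: " + vote(predictions, weights, 3));
    }
}
```

In this example the single stump with weight 0.9 outvotes the two stumps whose weights sum to about 0.7, so the vote is decided by weight and not by the number of stumps.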

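Computing the accuracy

1. Classify every instance in the data set as described in the previous section.

2. Accuracy = number of correctly classified instances / total number of instances.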
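As a minimal sketch of this computation, assuming the predicted and actual labels are already available as index arrays (the class `AccuracySketch` and its method are hypothetical and only illustrate the ratio):

```java
/** Hypothetical illustration of the accuracy ratio; not part of the original Booster code. */
final class AccuracySketch {

    /** Fraction of positions where the predicted label equals the actual label. */
    static double accuracy(int[] predicted, int[] actual) {
        if (predicted.length != actual.length) {
            throw new IllegalArgumentException("Arrays must have the same length.");
        }
        double correct = 0;
        for (int i = 0; i < predicted.length; i++) {
            if (predicted[i] == actual[i]) {
                correct++;
            }
        }
        return correct / predicted.length;
    }

    public static void main(String[] args) {
        int[] predicted = {0, 1, 2, 2};
        int[] actual = {0, 1, 1, 2};
        // Three of the four predictions match the actual labels.
        System.out.println("Accuracy: " + accuracy(predicted, actual));
    }
}
```

This accuracy is unweighted: every instance counts equally, regardless of the instance weights used during training.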

Data set (saved as iris.arff in the data folder on drive D):

@RELATION iris

@ATTRIBUTE sepallength	REAL
@ATTRIBUTE sepalwidth 	REAL
@ATTRIBUTE petallength 	REAL
@ATTRIBUTE petalwidth	REAL
@ATTRIBUTE class 	{Iris-setosa,Iris-versicolor,Iris-virginica}

@DATA
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

Code:

package 日撸Java300行_61_70;


import java.io.FileReader;
import weka.core.Instance;
import weka.core.Instances;

/**
 * The booster which ensembles base classifiers.
 * 
 * @author Hui Xiao
 */
public class Booster {

	/**
	 * Classifiers.
	 */
	SimpleClassifier[] classifiers;

	/**
	 * Number of classifiers.
	 */
	int numClassifiers;

	/**
	 * Whether or not to stop once the training error reaches 0.
	 */
	boolean stopAfterConverge = false;

	/**
	 * The weights of classifiers.
	 */
	double[] classifierWeights;

	/**
	 * The training data.
	 */
	Instances trainingData;

	/**
	 * The testing data.
	 */
	Instances testingData;

	/**
	 ****************** 
	 * The first constructor. The testing set is the same as the training set.
	 * 
	 * @param paraTrainingFilename
	 *            The data filename.
	 ****************** 
	 */
	public Booster(String paraTrainingFilename) {
		// Step 1. Read training set.
		try {
			FileReader tempFileReader = new FileReader(paraTrainingFilename);
			trainingData = new Instances(tempFileReader);
			tempFileReader.close();
		} catch (Exception ee) {
			System.out.println("Cannot read the file: " + paraTrainingFilename + "\r\n" + ee);
			System.exit(1); // Non-zero status: the file could not be read.
		} // Of try

		// Step 2. Set the last attribute as the class index.
		trainingData.setClassIndex(trainingData.numAttributes() - 1);

		// Step 3. The testing data is the same as the training data.
		testingData = trainingData;

		stopAfterConverge = true;

		System.out.println("****************Data**********\r\n" + trainingData);
	}// Of the first constructor

	/**
	 ****************** 
	 * Set the number of base classifiers, and allocate space for them.
	 * 
	 * @param paraNumBaseClassifiers
	 *            The number of base classifiers.
	 ****************** 
	 */
	public void setNumBaseClassifiers(int paraNumBaseClassifiers) {
		numClassifiers = paraNumBaseClassifiers;

		// Step 1. Allocate space (only reference) for classifiers
		classifiers = new SimpleClassifier[numClassifiers];

		// Step 2. Initialize classifier weights.
		classifierWeights = new double[numClassifiers];
	}// Of setNumBaseClassifiers

	/**
	 ****************** 
	 * Train the booster.
	 * 
	 * @see algorithm.StumpClassifier#train()
	 ****************** 
	 */
	public void train() {
		// Step 1. Initialize.
		WeightedInstances tempWeightedInstances = null;
		double tempError;
		numClassifiers = 0;

		// Step 2. Build other classifiers.
		for (int i = 0; i < classifiers.length; i++) {
			// Step 2.1 Key code: Construct or adjust the weightedInstances
			if (i == 0) {
				tempWeightedInstances = new WeightedInstances(trainingData);
			} else {
				// Adjust the weights of the data.
				tempWeightedInstances.adjustWeights(classifiers[i - 1].computeCorrectnessArray(),
						classifierWeights[i - 1]);
			} // Of if

			// Step 2.2 Train the next classifier.
			classifiers[i] = new StumpClassifier(tempWeightedInstances);
			classifiers[i].train();

			tempError = classifiers[i].computeWeightedError();

			// Key code: Set the classifier weight.
			classifierWeights[i] = 0.5 * Math.log(1 / tempError - 1);
			if (classifierWeights[i] < 1e-6) {
				classifierWeights[i] = 0;
			} // Of if

			System.out.println("Classifier #" + i + " , weighted error = " + tempError + ", weight = "
					+ classifierWeights[i] + "\r\n");

			numClassifiers++;

			// The accuracy is enough.
			if (stopAfterConverge) {
				double tempTrainingAccuracy = computeTrainingAccuray();
				System.out.println("The accuracy of the booster is: " + tempTrainingAccuracy + "\r\n");
				if (tempTrainingAccuracy > 0.999999) {
					System.out.println("Stop at the round: " + i + " due to convergence.\r\n");
					break;
				} // Of if
			} // Of if
		} // Of for i
	}// Of train

	/**
	 ****************** 
	 * Classify an instance.
	 * 
	 * @param paraInstance
	 *            The given instance.
	 * @return The predicted label.
	 ****************** 
	 */
	public int classify(Instance paraInstance) {
		double[] tempLabelsCountArray = new double[trainingData.classAttribute().numValues()];
		for (int i = 0; i < numClassifiers; i++) {
			int tempLabel = classifiers[i].classify(paraInstance);
			tempLabelsCountArray[tempLabel] += classifierWeights[i];
		} // Of for i

		int resultLabel = -1;
		double tempMax = -1;
		for (int i = 0; i < tempLabelsCountArray.length; i++) {
			if (tempMax < tempLabelsCountArray[i]) {
				tempMax = tempLabelsCountArray[i];
				resultLabel = i;
			} // Of if
		} // Of for

		return resultLabel;
	}// Of classify

	/**
	 ****************** 
	 * Test the booster on the testing data (here the same as the training data).
	 * 
	 * @return The classification accuracy.
	 ****************** 
	 */
	public double test() {
		System.out.println("Testing on " + testingData.numInstances() + " instances.\r\n");

		return test(testingData);
	}// Of test

	/**
	 ****************** 
	 * Test the booster.
	 * 
	 * @param paraInstances
	 *            The testing set.
	 * @return The classification accuracy.
	 ****************** 
	 */
	public double test(Instances paraInstances) {
		double tempCorrect = 0;
		paraInstances.setClassIndex(paraInstances.numAttributes() - 1);

		for (int i = 0; i < paraInstances.numInstances(); i++) {
			Instance tempInstance = paraInstances.instance(i);
			if (classify(tempInstance) == (int) tempInstance.classValue()) {
				tempCorrect++;
			} // Of if
		} // Of for i

		double resultAccuracy = tempCorrect / paraInstances.numInstances();
		System.out.println("The accuracy is: " + resultAccuracy);

		return resultAccuracy;
	} // Of test

	/**
	 ****************** 
	 * Compute the training accuracy of the booster. It is not weighted.
	 * 
	 * @return The training accuracy.
	 ****************** 
	 */
	public double computeTrainingAccuray() {
		double tempCorrect = 0;

		for (int i = 0; i < trainingData.numInstances(); i++) {
			if (classify(trainingData.instance(i)) == (int) trainingData.instance(i).classValue()) {
				tempCorrect++;
			} // Of if
		} // Of for i

		double tempAccuracy = tempCorrect / trainingData.numInstances();

		return tempAccuracy;
	}// Of computeTrainingAccuray

	/**
	 ****************** 
	 * For integration test.
	 * 
	 * @param args
	 *            Not provided.
	 ****************** 
	 */
	public static void main(String args[]) {
		System.out.println("Starting AdaBoosting...");
		Booster tempBooster = new Booster("D:/data/iris.arff");

		tempBooster.setNumBaseClassifiers(100);
		tempBooster.train();

		System.out.println("The training accuracy is: " + tempBooster.computeTrainingAccuray());
		tempBooster.test();
	}// Of main

}// Of class Booster

Screenshot:

The run reached an accuracy of 98%.

Posted by 社会演员多