Implementing a Random Forest with Spark (Code)
Date: 2021-04-23 09:26:31 | Category: Java code
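This article shares example code that trains a random forest classifier with Spark MLlib's Java API, for your reference. The details are as follows: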
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import junit.framework.TestCase;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.RandomForest;
import org.apache.spark.mllib.tree.model.RandomForestModel;
import org.apache.spark.mllib.util.MLUtils;

public class RandomForestClassficationTest extends TestCase implements Serializable
{
    private static final long serialVersionUID = 7802523720751354318L;

    /** Pairs the true label of a test point with the model's prediction for it. */
    static class PredictResult implements Serializable {
        private static final long serialVersionUID = -168308887976477219L;
        final double label;
        final double prediction;

        PredictResult(double label, double prediction) {
            this.label = label;
            this.prediction = prediction;
        }

        @Override
        public String toString() {
            return this.label + " : " + this.prediction;
        }
    }

    public void test_randomForest() {
        SparkConf sparkConf = new SparkConf()
                .setAppName("RandomForest")
                .setMaster("local");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        try {
            // sample_libsvm_data.txt is expected at the root of the test classpath
            String dataPath = RandomForestClassficationTest.class
                    .getResource("/sample_libsvm_data.txt").getPath();
            JavaRDD<LabeledPoint> dataSet = MLUtils.loadLibSVMFile(jsc.sc(), dataPath).toJavaRDD();

            // 70% training / 30% test; seed 1 makes the split reproducible
            JavaRDD<LabeledPoint>[] splits = dataSet.randomSplit(new double[]{0.7, 0.3}, 1L);
            JavaRDD<LabeledPoint> trainingData = splits[0];
            JavaRDD<LabeledPoint> testData = splits[1];

            /*
             * Training parameters:
             * 1 numClasses: number of classes, 2 for this binary data set
             * 2 categoricalFeatureInfos: empty map, so every feature is treated as continuous
             * 3 numTrees: number of trees in the random forest
             * 4 featureSubsetStrategy: number of features considered for splitting at each node;
             *   for classification "auto" means "sqrt" when numTrees > 1 ("all" when numTrees == 1)
             * 5 impurity: split criterion, "gini" or "entropy" for classification
             * 6 maxDepth: maximum depth of each tree
             * 7 maxBins: maximum number of bins used to discretize continuous features
             * 8 seed: random seed for bootstrapping and feature-subset sampling
             */
            int numClasses = 2;
            Map<Integer, Integer> categoricalFeatureInfos = new HashMap<>();
            int numTrees = 3;
            String featureSubsetStrategy = "auto";
            String impurity = "gini";
            int maxDepth = 4;
            int maxBins = 32;
            int seed = 1;

            final RandomForestModel model = RandomForest.trainClassifier(trainingData,
                    numClasses,
                    categoricalFeatureInfos,
                    numTrees,
                    featureSubsetStrategy,
                    impurity,
                    maxDepth,
                    maxBins,
                    seed);

            // Predict every test point; the lambda captures only the serializable model
            JavaRDD<PredictResult> predictRddResult = testData.map(
                    point -> new PredictResult(point.label(), model.predict(point.features())));

            List<PredictResult> predictResultList = predictRddResult.collect();
            for (PredictResult result : predictResultList) {
                System.out.println(result);
            }
            System.out.println(model.toDebugString());
        } finally {
            jsc.stop();
        }
    }
}
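The loop above prints one `label : prediction` pair per test point. A common way to summarize those pairs is the test error, the fraction of test points whose predicted class differs from the true label. Below is a minimal plain-Java sketch of that calculation. It assumes the labels and predictions have already been collected into two parallel arrays. The class `TestErrorSketch`, the method `testError` and the sample values are hypothetical and were not produced by the run in this article.

```java
/** Hypothetical helper: test error from parallel arrays of labels and predictions. */
public class TestErrorSketch {

    /** Fraction of points whose prediction differs from the true label. */
    static double testError(double[] labels, double[] predictions) {
        if (labels.length != predictions.length || labels.length == 0) {
            throw new IllegalArgumentException("need two non-empty arrays of equal length");
        }
        int wrong = 0;
        for (int i = 0; i < labels.length; i++) {
            // Class labels are stored as doubles (0.0 or 1.0 for binary data)
            if (labels[i] != predictions[i]) {
                wrong++;
            }
        }
        return (double) wrong / labels.length;
    }

    public static void main(String[] args) {
        // Made-up "label : prediction" pairs, not output from the run above
        double[] labels      = {1.0, 0.0, 1.0, 1.0, 0.0};
        double[] predictions = {1.0, 0.0, 0.0, 1.0, 0.0};
        // One of the five pairs disagrees, so this prints 0.2
        System.out.println("Test error = " + testError(labels, predictions));
    }
}
```

Working on two plain arrays keeps the sketch independent of Spark. For a large test set it would be cheaper to count the mismatches on the RDD itself than to collect every pair to the driver first.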
The trained random forest, as printed by model.toDebugString(), looks like this:
TreeEnsembleModel classifier with 3 trees
  Tree 0:
    If (feature 435 <= 0.0)
     If (feature 516 <= 0.0)
      Predict: 0.0
     Else (feature 516 > 0.0)
      Predict: 1.0
    Else (feature 435 > 0.0)
     Predict: 1.0
  Tree 1:
    If (feature 512 <= 0.0)
     Predict: 1.0
    Else (feature 512 > 0.0)
     Predict: 0.0
  Tree 2:
    If (feature 377 <= 1.0)
     Predict: 0.0
    Else (feature 377 > 1.0)
     If (feature 455 <= 0.0)
      Predict: 1.0
     Else (feature 455 > 0.0)
      Predict: 0.0
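For classification, Spark MLlib's random forest combines its trees by majority vote. Each tree predicts a class, and the forest returns the class predicted by the most trees. To make that concrete, the sketch below hand-codes the three printed trees as plain Java conditionals and takes the vote. It is a hypothetical illustration, not Spark's implementation. The class `ForestVoteSketch` and its methods are made up, and features are assumed to be stored sparsely as in LIBSVM, where a missing index means 0.0. Breaking ties toward the smaller label is a choice of this sketch; with three trees and two classes a tie cannot occur.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical re-encoding of the three printed trees, combined by majority vote. */
public class ForestVoteSketch {

    // Sparse feature vector, LIBSVM style: absent indices are 0.0
    static double f(Map<Integer, Double> x, int index) {
        return x.getOrDefault(index, 0.0);
    }

    static double tree0(Map<Integer, Double> x) {
        if (f(x, 435) <= 0.0) {
            return f(x, 516) <= 0.0 ? 0.0 : 1.0;
        }
        return 1.0;
    }

    static double tree1(Map<Integer, Double> x) {
        return f(x, 512) <= 0.0 ? 1.0 : 0.0;
    }

    static double tree2(Map<Integer, Double> x) {
        if (f(x, 377) <= 1.0) {
            return 0.0;
        }
        return f(x, 455) <= 0.0 ? 1.0 : 0.0;
    }

    /** Each tree casts one vote; the class with the most votes wins. */
    static double predict(Map<Integer, Double> x) {
        double[] treePredictions = {tree0(x), tree1(x), tree2(x)};
        // Count one vote per tree for the class it predicts
        TreeMap<Double, Integer> votes = new TreeMap<>();
        for (double p : treePredictions) {
            votes.merge(p, 1, Integer::sum);
        }
        // TreeMap iterates labels in ascending order; the strict '>' keeps
        // the smaller label on a tie (a choice made by this sketch)
        double best = Double.NaN;
        int bestCount = -1;
        for (Map.Entry<Double, Integer> e : votes.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        Map<Integer, Double> allZero = new HashMap<>();   // every feature is 0.0
        Map<Integer, Double> x = new HashMap<>();
        x.put(435, 1.0);                                  // hypothetical non-zero feature
        System.out.println(predict(allZero));  // trees vote 0.0, 1.0, 0.0 -> prints 0.0
        System.out.println(predict(x));        // trees vote 1.0, 1.0, 0.0 -> prints 1.0
    }
}
```

With every feature at 0.0, trees 0 and 2 vote 0.0 and tree 1 votes 1.0, so the forest predicts 0.0. Setting feature 435 to 1.0 flips tree 0 to 1.0, and the majority becomes 1.0.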
Article URL: http://www.codeinn.net/misctech/106753.html