博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Spark随机森林算法实践
阅读量:6610 次
发布时间:2019-06-24

本文共 5158 字,大约阅读时间需要 17 分钟。

hot3.png

    1. 例子1

object RunRF {

  def main(args: Array[String]) {

    val sparkConf = new SparkConf().setAppName("rf")

    val sc = new SparkContext(sparkConf)

    //读取数据

    val rawData = sc.textFile("hdfs://192.168.1.64:8020/test/mllib/v3.csv")

    val data = rawData.map{ line =>

      val values = line.split(",").map(_.toDouble)

      //init返回除了最后一个元素的所有元素,作为特征向量

      //Vectors.dense向量化,dense密集型

      val feature = Vectors.dense(values.init)

      val label = values.last

      LabeledPoint(label, feature)

    }

    //训练集、交叉验证集和测试集,各占80%10%10%

    //10%的交叉验证数据集的作用是确定在训练数据集上训练出来的模型的最好参数

    //测试数据集的作用是评估CV数据集的最好参数

    val Array(trainData, cvData, testData) = data.randomSplit(Array(0.8, 0.1, 0.1))

    trainData.cache()

    cvData.cache()

    testData.cache()

 

    //构建随机森林

    val model = RandomForest.trainClassifier(trainData, 2, Map[Int, Int](), 20, "auto", "gini", 4, 32)

    val metrics = getMetrics(model, cvData)

    println("-----------------------------------------confusionMatrix-----------------------------------------------------")

    //混淆矩阵和模型精确率

    println(metrics.confusionMatrix)

    println("---------------------------------------------precision-------------------------------------------------")

    println(metrics.precision)

 

    println("-----------------------------------------(precision,recall)---------------------------------------------------")

    //每个类别对应的精确率与召回率

    (0 until 2).map(target => (metrics.precision(target), metrics.recall(target))).foreach(println)

    //保存模型

    model.save(sc,"hdfs://192.168.1.64:8020/tmp/RFModel")

 

  }

 

/**

    * model 随机森林模型

    * data  用于交叉验证的数据集

    * */

  def getMetrics(model: RandomForestModel, data: RDD[LabeledPoint]): MulticlassMetrics = {

    //将交叉验证数据集的每个样本的特征向量交给模型预测,并和原本正确的目标特征组成一个tuple

    val predictionsAndLables = data.map { d =>

      (model.predict(d.features), d.label)

    }

    //将结果交给MulticlassMetrics,其可以以不同的方式计算分配器预测的质量

    new MulticlassMetrics(predictionsAndLables)

  }

 

/**

    * 在训练数据集上得到最好的参数组合

    * trainData 训练数据集

    * cvData 交叉验证数据集

    * */

  def getBestParam(trainData: RDD[LabeledPoint], cvData: RDD[LabeledPoint]): Unit = {

    val evaluations = for (impurity <- Array("gini", "entropy");

                           depth <- Array(1, 20);

                           bins <- Array(10, 300)) yield {

      val model = RandomForest.trainClassifier(trainData, 2, Map[Int, Int](), 20, "auto", impurity, depth, bins)

// 2classes

// 20: numTrees

// auto:subSampleStratry

      val metrics = getMetrics(model, cvData)

      ((impurity, depth, bins), metrics.precision)

    }

    evaluations.sortBy(_._2).reverse.foreach(println)

  }

 

/**

    * 模拟对新数据进行预测1

    */

  val rawData = sc.textFile("hdfs://192.168.1.64:8020/test/mllib/v3.csv")

 

  val pdata = rawData.map{ line =>

    val values = line.split(",").map(_.toDouble)

    //转化为向量并去掉标签(init去掉最后一个元素,即去掉标签)

    val feature = Vectors.dense(values.init)

    feature

  }

  //读取模型

  val rfModel = RandomForestModel.load(sc,"hdfs://192.168.1.64:8020/tmp/RFModel")

  //进行预测

  val preLabel = rfModel.predict(pdata)

  preLabel.take(10)

  /**

    * 模拟对新数据进行预测2

    *

    */

  val dataAndPreLable = rawData.map{ line =>

    //转化为向量并去掉标签(init去掉最后一个元素,即去掉标签)

    val vecData = Vectors.dense(line.split(",").map(_.toDouble).init)

    val preLabel = rfModel.predict(vecData)

    line + "," + preLabel

  }//.saveAsTextFile("....")

dataAndPreLable.take(10)

}

 

    1. 例子2:处理Hive数据

val hc = new HiveContext(sc)

    import hc.implicits._

    // 调用HiveContext

 

    // 取样本,样本的第一列为label0或者1),其他列可能是姓名,手机号,以及真正要参与训练的特征columns

    val data = hc.sql(s"""select  *  from database1.traindata_userprofile""".stripMargin)

    //提取schema,也就是表的column namedrop2)删掉1,2列,只保留特征列

 

    val schema = data.schema.map(f=>s"${f.name}").drop(2)

 

    //MLVectorAssembler是一个transformer,要求数据类型不能是string,将多列数据转化为单列的向量列,比如把ageincome等等字段列合并成一个 userFea 向量列,方便后续训练

    val assembler = new VectorAssembler().setInputCols(schema.toArray).setOutputCol("userFea")

    val userProfile = assembler.transform(data.na.fill(-1e9)).select("label","userFea")

    val data_train = userProfile.na.fill(-1e9)

    // 取训练样本

    val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(userProfile)

    val featureIndexer = new VectorIndexer().setInputCol("userFea").setOutputCol("indexedFeatures").setMaxCategories(4).fit(userProfile)

 

    // Split the data into training and test sets (30% held out for testing).

    val Array(trainingData, testData) = userProfile.randomSplit(Array(0.7, 0.3))

    // Train a RandomForest model.

    val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures")

    rf.setMaxBins(32).setMaxDepth(6).setNumTrees(90).setMinInstancesPerNode(4).setImpurity("gini")

    // Convert indexed labels back to original labels.

    val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

 

    val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

 

    // Train model. This also runs the indexers.

    val model = pipeline.fit(trainingData)

    println("training finished!!!!")

    // Make predictions.

    val predictions = model.transform(testData)

 

    // Select example rows to display.

    predictions.select("predictedLabel", "indexedLabel", "indexedFeatures").show(5)

 

    val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")

    val accuracy = evaluator.evaluate(predictions)

    println("Test Error = " + (1.0 - accuracy))

 

转载于:https://my.oschina.net/u/778683/blog/1831260

你可能感兴趣的文章
高淇java300集JAVA常用类作业
查看>>
<Linux命令行学习 第一节> CentOS在虚拟机的安装
查看>>
无Paper不论文
查看>>
mysql设置字符集CHARACTER SET
查看>>
redis 系列15 数据对象的(类型检查,内存回收,对象共享)和数据库切换
查看>>
log框架集成
查看>>
python命令行下安装redis客户端
查看>>
如何在Oracle中复制表结构和表数据
查看>>
[河南省ACM省赛-第四届] 序号互换 (nyoj 303)
查看>>
3 Oracle 32位客户端安装及arcgis连接
查看>>
[MFC] MFC编译程序,缺少MFC动态链接库的解决
查看>>
Sticker.js – 帮助你在网站中加入贴纸效果
查看>>
欧拉路与欧拉回路的性质
查看>>
iOS之UI--关于modal
查看>>
各种U启网启什么的都是浮云
查看>>
请问JDBC中IN语句怎么构建
查看>>
2015第52周六
查看>>
UIScrollView设置了contentSize后还是没办法滚动?
查看>>
POJ 1205 Water Treatment Plants(递推)
查看>>
国内外DNS服务器地址列表
查看>>