本地测试Spark的svm算法

发布时间 2023-09-26 18:02:23作者: 人不疯狂枉一生

上一篇介绍了逻辑回归算法,当时分类效果不好。这次用 SVM 测试后发现,问题出在训练数据质量不行;于是从网上找了一部分训练数据,换用之后分类效果实际上还可以。

    训练数据如下:每行 `|` 前面的值是标签(0 或 1),`|` 后面以逗号分隔的四个数值是某种花的相关特征。

1|5.1,3.5,1.4,0.2
1|4.9,3,1.4,0.2
1|4.7,3.2,1.3,0.2
1|4.6,3.1,1.5,0.2
1|5,3.6,1.4,0.2
1|5.4,3.9,1.7,0.4
1|4.6,3.4,1.4,0.3
1|5,3.4,1.5,0.2
1|4.4,2.9,1.4,0.2
1|4.9,3.1,1.5,0.1
1|5.4,3.7,1.5,0.2
1|4.8,3.4,1.6,0.2
1|4.8,3,1.4,0.1
1|4.3,3,1.1,0.1
1|5.8,4,1.2,0.2
1|5.7,4.4,1.5,0.4
1|5.4,3.9,1.3,0.4
1|5.1,3.5,1.4,0.3
1|5.7,3.8,1.7,0.3
1|5.1,3.8,1.5,0.3
1|5.4,3.4,1.7,0.2
1|5.1,3.7,1.5,0.4
1|4.6,3.6,1,0.2
1|5.1,3.3,1.7,0.5
1|4.8,3.4,1.9,0.2
0|7,3.2,4.7,1.4
0|6.4,3.2,4.5,1.5
0|6.9,3.1,4.9,1.5
0|5.5,2.3,4,1.3
0|6.5,2.8,4.6,1.5
0|5.7,2.8,4.5,1.3
0|6.3,3.3,4.7,1.6
0|4.9,2.4,3.3,1
0|6.6,2.9,4.6,1.3
0|5.2,2.7,3.9,1.4
0|5,2,3.5,1
0|5.9,3,4.2,1.5
0|6,2.2,4,1
0|6.1,2.9,4.7,1.4
0|5.6,2.9,3.6,1.3

  测试数据如下。

0|5.1,2.5,3,1.1
0|5.7,2.8,4.1,1.3
1|5,3,1.6,0.2
1|5,3.4,1.6,0.4

    svm代码跟逻辑回归类似,只需替换算法即可。

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.L1Updater
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.{SparkConf, SparkContext}

/**
 * Local smoke test for Spark MLlib's linear SVM (`SVMWithSGD`).
 *
 * Reads `label|f1,f2,...` lines from a text file, trains an L1-regularised SVM on them,
 * prints the prediction for one hand-written sample, and finally prints the area under
 * the ROC curve computed over the same training set.
 */
object TestSvmAlgorithm {

  /** Input location used when no path is passed on the command line. */
  private val DefaultDataPath = "file:///D:\\var\\11.txt"

  /**
   * Entry point.
   *
   * @param args optional; when present, `args(0)` replaces the default training-data path
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("svm").setMaster("local").set("spark.testing.memory", "2147480000")
    val sparkContext = new SparkContext(sparkConf)
    try {
      val dataPath = args.headOption.getOrElse(DefaultDataPath)
      val dataSpark = sparkContext.textFile(dataPath)

      // Each non-blank line is "<label>|<comma-separated features>".
      // Blank lines are dropped so they do not blow up on tmpLine(1).
      val trainData = dataSpark
        .filter(_.trim.nonEmpty)
        .map { line =>
          val tmpLine = line.split("\\|")
          val label = tmpLine(0).toDouble
          val features = Vectors.dense(tmpLine(1).split("\\,").map(_.toDouble))
          // Explicit tuple: prints the same text as before without relying on
          // deprecated argument auto-tupling of println(a, b).
          println(("数据:" + label, features))
          LabeledPoint(label, features)
        }
        .cache() // reused below for both training and evaluation

      // Hyper-parameters actually applied to the optimizer.
      val numIterations = 10
      val regParam = 0.1

      val svmModel = new SVMWithSGD()
      svmModel.optimizer
        .setNumIterations(numIterations)
        .setRegParam(regParam)
        .setUpdater(new L1Updater())
      val model = svmModel.run(trainData)

      val predictData = Vectors.dense(6.6, 3, 4.4, 1.4)
      println(predictData)
      val result = model.predict(predictData)
      println(result)

      // BinaryClassificationMetrics expects (score, label) pairs, in that order.
      // The log line keeps its (label, prediction) layout.
      // NOTE(review): the model still has its classification threshold set, so the
      // "score" here is the predicted class rather than a raw margin; call
      // model.clearThreshold() first if a full ROC curve is wanted.
      // NOTE(review): this evaluates on the training set itself, not on held-out data.
      val scoreAndLabels = trainData.map { p =>
        val predi = model.predict(p.features)
        println("预测" + (p.label, predi))
        (predi, p.label)
      }

      val metrics = new BinaryClassificationMetrics(scoreAndLabels)
      val auRoc = metrics.areaUnderROC()
      println(":" + auRoc)
    } finally {
      // Always release the local Spark runtime, even if training or evaluation fails.
      sparkContext.stop()
    }
  }
}