author     VinceShieh <vincent.xie@intel.com>  2016-11-17 13:37:42 +0000
committer  Sean Owen <sowen@cloudera.com>      2016-11-17 13:37:42 +0000
commit     de77c67750dc868d75d6af173c3820b75a9fe4b7 (patch)
tree       2ae4bc8e4f25e330c64071f9db3b5724dc3df9ca
parent     49b6f456aca350e9e2c170782aa5cc75e7822680 (diff)
[SPARK-17462][MLLIB] Use VersionUtils to parse Spark version strings
## What changes were proposed in this pull request?

Several places in MLlib use custom regexes or other approaches to parse Spark versions. Those should be fixed to use VersionUtils. This PR replaces the custom regexes with VersionUtils to get the Spark version numbers.

## How was this patch tested?

Existing tests.

Signed-off-by: VinceShieh <vincent.xie@intel.com>

Author: VinceShieh <vincent.xie@intel.com>

Closes #15055 from VinceShieh/SPARK-17462.
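To see the change in miniature, here is a self-contained before/after sketch. `majorViaLocalRegex` reproduces the exact pattern the diffs below remove; `majorVersion` is only an approximation of what a shared helper like `org.apache.spark.util.VersionUtils.majorVersion` does, not Spark's actual source:

```scala
object VersionParsingSketch {
  // Before: each loader declared its own regex and destructured it inline
  // (this is the pattern removed from KMeansModel.load and PCAModel.load).
  def majorViaLocalRegex(sparkVersion: String): Int = {
    val versionRegex = "([0-9]+)\\.(.+)".r
    val versionRegex(major, _) = sparkVersion
    major.toInt
  }

  // After: every call site delegates to one shared helper. This body is an
  // approximation for illustration, not Spark's actual implementation.
  def majorVersion(sparkVersion: String): Int = {
    val versionRegex = """^(\d+)\.(\d+)(\..*)?$""".r
    sparkVersion match {
      case versionRegex(major, _, _) => major.toInt
      case _ =>
        throw new IllegalArgumentException(s"Cannot parse Spark version: $sparkVersion")
    }
  }

  def main(args: Array[String]): Unit = {
    // Release and snapshot-style strings resolve to the same major version.
    for (v <- Seq("2.0.2", "1.6.3", "2.1.0-SNAPSHOT")) {
      println(s"$v -> major ${majorVersion(v)} (local regex: ${majorViaLocalRegex(v)})")
    }
  }
}
```

Centralizing the parse means the format assumptions live in one tested place instead of being restated at every model loader.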
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala | 6
2 files changed, 4 insertions(+), 8 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index a0d481b294..26505b4cc1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -33,6 +33,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
/**
* Common params for KMeans and KMeansModel
@@ -232,10 +233,7 @@ object KMeansModel extends MLReadable[KMeansModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val versionRegex = "([0-9]+)\\.(.+)".r
- val versionRegex(major, _) = metadata.sparkVersion
-
- val clusterCenters = if (major.toInt >= 2) {
+ val clusterCenters = if (majorVersion(metadata.sparkVersion) >= 2) {
val data: Dataset[Data] = sparkSession.read.parquet(dataPath).as[Data]
data.collect().sortBy(_.clusterIdx).map(_.clusterCenter).map(OldVectors.fromML)
} else {
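Both loaders follow the same shape after this change: parse the persisted `metadata.sparkVersion` once via the shared helper, then branch on the major version to pick the on-disk layout. A hedged, self-contained illustration of that pattern, where `Metadata` and the returned strings are hypothetical stand-ins for the real loader types:

```scala
object VersionGatedLoadSketch {
  // Hypothetical stand-in for the metadata DefaultParamsReader loads from disk.
  final case class Metadata(sparkVersion: String)

  // Stand-in parser; the real helper is org.apache.spark.util.VersionUtils.majorVersion.
  private val versionRegex = """^(\d+)\.(\d+)(\..*)?$""".r
  def majorVersion(v: String): Int = v match {
    case versionRegex(major, _, _) => major.toInt
    case _ => throw new IllegalArgumentException(s"Cannot parse Spark version: $v")
  }

  // The shape shared by KMeansModel.load and PCAModel.load after this change:
  // one helper call, then branch on which save format wrote the data.
  def chooseLayout(metadata: Metadata): String =
    if (majorVersion(metadata.sparkVersion) >= 2) "read Spark 2.x Parquet layout"
    else "read legacy pre-2.0 layout"

  def main(args: Array[String]): Unit = {
    println(chooseLayout(Metadata("2.0.2")))  // new layout
    println(chooseLayout(Metadata("1.6.3")))  // legacy layout
  }
}
```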
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 444006fe1e..1e49352b85 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -34,6 +34,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.util.VersionUtils.majorVersion
/**
* Params for [[PCA]] and [[PCAModel]].
@@ -204,11 +205,8 @@ object PCAModel extends MLReadable[PCAModel] {
override def load(path: String): PCAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
- val versionRegex = "([0-9]+)\\.(.+)".r
- val versionRegex(major, _) = metadata.sparkVersion
-
val dataPath = new Path(path, "data").toString
- val model = if (major.toInt >= 2) {
+ val model = if (majorVersion(metadata.sparkVersion) >= 2) {
val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
sparkSession.read.parquet(dataPath)
.select("pc", "explainedVariance")