author     Yuhao Yang <yuhao.yang@intel.com>   2016-06-23 21:50:25 -0700
committer  Xiangrui Meng <meng@databricks.com> 2016-06-23 21:50:25 -0700
commit     cc6778ee0bf4fa7a78abd30542c4a6f80ea371c5 (patch)
tree       72c2479fc8ade96743428c3b2ad376d1c7770358 /mllib
parent     4a40d43bb29704734b8128bf2a3f27802ae34e17 (diff)
[SPARK-16133][ML] model loading backward compatibility for ml.feature
## What changes were proposed in this pull request?

Model loading backward compatibility for ml.feature.

## How was this patch tested?

Existing unit tests, plus a manual test loading models saved by Spark 1.6.

Author: Yuhao Yang <yuhao.yang@intel.com>
Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #13844 from hhbyyh/featureComp.
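Spark 2.0 moved the public vector types from org.apache.spark.mllib.linalg to org.apache.spark.ml.linalg, so model files written by Spark 1.6 store vector columns with the old UDT. Each reader below therefore routes the loaded DataFrame through MLUtils.convertVectorColumnsToML before extracting vectors. A minimal sketch of the pattern, where the parquet path and the `spark` session are illustrative assumptions:

```scala
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{Row, SparkSession}

val spark = SparkSession.builder().appName("compat-sketch").getOrCreate()

// Hypothetical location of model data written by Spark 1.6 or 2.0.
val data = spark.read.parquet("/tmp/idf-model/data")

// convertVectorColumnsToML rewrites any mllib.linalg vector columns into
// ml.linalg vectors and leaves already-converted columns untouched, so one
// code path reads files from either version.
val Row(idf: Vector) = MLUtils.convertVectorColumnsToML(data, "idf")
  .select("idf")
  .head()
```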
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala             3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala    9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala  4
3 files changed, 11 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 02d4e6a9f7..5d6287f0e3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
@@ -180,9 +181,9 @@ object IDFModel extends MLReadable[IDFModel] {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.parquet(dataPath)
+      val Row(idf: Vector) = MLUtils.convertVectorColumnsToML(data, "idf")
         .select("idf")
         .head()
-      val idf = data.getAs[Vector](0)
       val model = new IDFModel(metadata.uid, new feature.IDFModel(OldVectors.fromML(idf)))
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
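The conversion is what makes the typed extraction safe: `Vector` here is the new org.apache.spark.ml.linalg.Vector, so reading a 1.6 file without converting fails at runtime with a type mismatch. A sketch of the failure mode this hunk removes, reusing `data` from above:

```scala
// Pre-fix sketch: on a Spark 1.6 model file the column holds an
// org.apache.spark.mllib.linalg.Vector, so reading it at the new type fails
// at runtime -- a ClassCastException here, or a scala.MatchError when a
// Row(v: Vector) extractor is used instead.
val idf = data.select("idf").head().getAs[org.apache.spark.ml.linalg.Vector](0)
```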
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 562b3f38e4..d5ad5abced 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -28,6 +28,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
@@ -232,9 +233,11 @@ object MinMaxScalerModel extends MLReadable[MinMaxScalerModel] {
     override def load(path: String): MinMaxScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(originalMin: Vector, originalMax: Vector) = sparkSession.read.parquet(dataPath)
-        .select("originalMin", "originalMax")
-        .head()
+      val data = sparkSession.read.parquet(dataPath)
+      val Row(originalMin: Vector, originalMax: Vector) =
+        MLUtils.convertVectorColumnsToML(data, "originalMin", "originalMax")
+          .select("originalMin", "originalMax")
+          .head()
       val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
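convertVectorColumnsToML is variadic, so both stored vector columns are upgraded in a single pass, and the public load API becomes version-agnostic. A usage sketch, where the model directory path is a hypothetical example:

```scala
import org.apache.spark.ml.feature.MinMaxScalerModel

// Directory produced earlier by model.save(...), under Spark 1.6 or 2.0.
val model = MinMaxScalerModel.load("/models/minmax")
println(s"originalMin=${model.originalMin} originalMax=${model.originalMax}")
```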
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index be58dc27e0..b4be95494f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -28,6 +28,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
@@ -211,7 +212,8 @@ object StandardScalerModel extends MLReadable[StandardScalerModel] {
     override def load(path: String): StandardScalerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(std: Vector, mean: Vector) = sparkSession.read.parquet(dataPath)
+      val data = sparkSession.read.parquet(dataPath)
+      val Row(std: Vector, mean: Vector) = MLUtils.convertVectorColumnsToML(data, "std", "mean")
         .select("std", "mean")
         .head()
       val model = new StandardScalerModel(metadata.uid, std, mean)
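StandardScalerModel follows the same two-step read, converting both stored vectors before the Row(std: Vector, mean: Vector) match. A usage sketch, with a hypothetical path:

```scala
import org.apache.spark.ml.feature.StandardScalerModel

// Both the "std" and "mean" columns are converted to ml.linalg vectors on
// load, so a directory written by Spark 1.6 loads without error.
val scaler = StandardScalerModel.load("/models/standard-scaler")
println(s"std=${scaler.std} mean=${scaler.mean}")
```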