author     Yanbo Liang <ybliang8@gmail.com>    2015-11-19 22:02:17 -0800
committer  Xiangrui Meng <meng@databricks.com>    2015-11-19 22:02:17 -0800
commit     3b7f056da87a23f3a96f0311b3a947a9b698f38b
tree       076f2f6abaa2992b91a35da9aaf9a8152fdb41b2
parent     4114ce20fbe820f111e55e891ae3889b0e6e0006
[SPARK-11829][ML] Add read/write to estimators under ml.feature (II)
Add read/write support to the following estimators under spark.ml:
* ChiSqSelector
* PCA
* VectorIndexer
* Word2Vec

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9838 from yanboliang/spark-11829.
 mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala      | 65
 mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala                | 67
 mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala      | 66
 mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala           | 67
 mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala        |  6
 mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala | 22
 mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala           | 26
 mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala | 22
 mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala      | 30
 9 files changed, 338 insertions(+), 33 deletions(-)
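
A rough usage sketch of what this patch enables (not part of the diff): the estimators mix in DefaultParamsWritable/DefaultParamsReadable and the fitted models mix in MLWritable/MLReadable, so both can be persisted and reloaded. The DataFrame `df` and the /tmp paths below are placeholders.

    import org.apache.spark.ml.feature.{ChiSqSelector, ChiSqSelectorModel}

    val selector = new ChiSqSelector()
      .setNumTopFeatures(2)
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setOutputCol("selectedFeatures")

    // The unfitted estimator round-trips through DefaultParamsWritable/Readable.
    selector.write.overwrite().save("/tmp/chisq-selector")
    val sameSelector = ChiSqSelector.load("/tmp/chisq-selector")

    // The fitted model round-trips through MLWritable/MLReadable.
    val model = selector.fit(df)  // df: a DataFrame with "features" and "label" columns (placeholder)
    model.write.overwrite().save("/tmp/chisq-selector-model")
    val sameModel = ChiSqSelectorModel.load("/tmp/chisq-selector-model")
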
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 5e4061fba5..dfec03828f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -17,13 +17,14 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.{AttributeGroup, _}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -60,7 +61,7 @@ private[feature] trait ChiSqSelectorParams extends Params
*/
@Experimental
final class ChiSqSelector(override val uid: String)
- extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams {
+ extends Estimator[ChiSqSelectorModel] with ChiSqSelectorParams with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("chiSqSelector"))
@@ -95,6 +96,13 @@ final class ChiSqSelector(override val uid: String)
override def copy(extra: ParamMap): ChiSqSelector = defaultCopy(extra)
}
+@Since("1.6.0")
+object ChiSqSelector extends DefaultParamsReadable[ChiSqSelector] {
+
+ @Since("1.6.0")
+ override def load(path: String): ChiSqSelector = super.load(path)
+}
+
/**
* :: Experimental ::
* Model fitted by [[ChiSqSelector]].
@@ -103,7 +111,12 @@ final class ChiSqSelector(override val uid: String)
final class ChiSqSelectorModel private[ml] (
override val uid: String,
private val chiSqSelector: feature.ChiSqSelectorModel)
- extends Model[ChiSqSelectorModel] with ChiSqSelectorParams {
+ extends Model[ChiSqSelectorModel] with ChiSqSelectorParams with MLWritable {
+
+ import ChiSqSelectorModel._
+
+ /** list of indices to select (filter). Must be ordered asc */
+ val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures
/** @group setParam */
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
@@ -147,4 +160,46 @@ final class ChiSqSelectorModel private[ml] (
val copied = new ChiSqSelectorModel(uid, chiSqSelector)
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter = new ChiSqSelectorModelWriter(this)
+}
+
+@Since("1.6.0")
+object ChiSqSelectorModel extends MLReadable[ChiSqSelectorModel] {
+
+ private[ChiSqSelectorModel]
+ class ChiSqSelectorModelWriter(instance: ChiSqSelectorModel) extends MLWriter {
+
+ private case class Data(selectedFeatures: Seq[Int])
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.selectedFeatures.toSeq)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class ChiSqSelectorModelReader extends MLReader[ChiSqSelectorModel] {
+
+ private val className = classOf[ChiSqSelectorModel].getName
+
+ override def load(path: String): ChiSqSelectorModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath).select("selectedFeatures").head()
+ val selectedFeatures = data.getAs[Seq[Int]](0).toArray
+ val oldModel = new feature.ChiSqSelectorModel(selectedFeatures)
+ val model = new ChiSqSelectorModel(metadata.uid, oldModel)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: MLReader[ChiSqSelectorModel] = new ChiSqSelectorModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): ChiSqSelectorModel = super.load(path)
}
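
For reference, a sketch (not from the patch; paths are hypothetical) of inspecting what ChiSqSelectorModelWriter.saveImpl lays down: the Params metadata goes through DefaultParamsWriter, and the model payload is a single-row Parquet file under <path>/data holding the selected indices.

    // Assumes a model was previously saved to this path and an existing sqlContext.
    val modelPath = "/tmp/chisq-selector-model"
    val data = sqlContext.read.parquet(modelPath + "/data")
    data.printSchema()                      // selectedFeatures: array of int
    data.select("selectedFeatures").show()  // one row with the ordered indices
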
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 539084704b..32d7afee6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -17,13 +17,15 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
@@ -49,7 +51,8 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
* PCA trains a model to project vectors to a low-dimensional space using PCA.
*/
@Experimental
-class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams {
+class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
+ with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("pca"))
@@ -86,6 +89,13 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
override def copy(extra: ParamMap): PCA = defaultCopy(extra)
}
+@Since("1.6.0")
+object PCA extends DefaultParamsReadable[PCA] {
+
+ @Since("1.6.0")
+ override def load(path: String): PCA = super.load(path)
+}
+
/**
* :: Experimental ::
* Model fitted by [[PCA]].
@@ -94,7 +104,12 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
class PCAModel private[ml] (
override val uid: String,
pcaModel: feature.PCAModel)
- extends Model[PCAModel] with PCAParams {
+ extends Model[PCAModel] with PCAParams with MLWritable {
+
+ import PCAModel._
+
+ /** a principal components Matrix. Each column is one principal component. */
+ val pc: DenseMatrix = pcaModel.pc
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -127,4 +142,46 @@ class PCAModel private[ml] (
val copied = new PCAModel(uid, pcaModel)
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter = new PCAModelWriter(this)
+}
+
+@Since("1.6.0")
+object PCAModel extends MLReadable[PCAModel] {
+
+ private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {
+
+ private case class Data(k: Int, pc: DenseMatrix)
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.getK, instance.pc)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class PCAModelReader extends MLReader[PCAModel] {
+
+ private val className = classOf[PCAModel].getName
+
+ override def load(path: String): PCAModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
+ .select("k", "pc")
+ .head()
+ val oldModel = new feature.PCAModel(k, pc)
+ val model = new PCAModel(metadata.uid, oldModel)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: MLReader[PCAModel] = new PCAModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): PCAModel = super.load(path)
}
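
A hedged end-to-end sketch for PCA (illustrative only; the data mirrors PCASuite and the path is a placeholder): fit a PCAModel, save it, reload it, and check that the principal components matrix survives the round trip.

    import org.apache.spark.ml.feature.{PCA, PCAModel}
    import org.apache.spark.mllib.linalg.Vectors

    val data = Seq(
      Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))),
      Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0),
      Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)
    )
    val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("features")

    val pcaModel = new PCA()
      .setInputCol("features")
      .setOutputCol("pcaFeatures")
      .setK(3)
      .fit(df)

    pcaModel.write.overwrite().save("/tmp/pca-model")  // persists k and the pc matrix
    val restored = PCAModel.load("/tmp/pca-model")
    assert(restored.pc.toArray.sameElements(pcaModel.pc.toArray))
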
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 52e0599e38..a637a6f288 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -22,12 +22,14 @@ import java.util.{Map => JMap}
import scala.collection.JavaConverters._
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute._
-import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
+import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.udf
@@ -93,7 +95,7 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
*/
@Experimental
class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerModel]
- with VectorIndexerParams {
+ with VectorIndexerParams with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("vecIdx"))
@@ -136,7 +138,11 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
override def copy(extra: ParamMap): VectorIndexer = defaultCopy(extra)
}
-private object VectorIndexer {
+@Since("1.6.0")
+object VectorIndexer extends DefaultParamsReadable[VectorIndexer] {
+
+ @Since("1.6.0")
+ override def load(path: String): VectorIndexer = super.load(path)
/**
* Helper class for tracking unique values for each feature.
@@ -146,7 +152,7 @@ private object VectorIndexer {
* @param numFeatures This class fails if it encounters a Vector whose length is not numFeatures.
* @param maxCategories This class caps the number of unique values collected at maxCategories.
*/
- class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
+ private class CategoryStats(private val numFeatures: Int, private val maxCategories: Int)
extends Serializable {
/** featureValueSets[feature index] = set of unique values */
@@ -252,7 +258,9 @@ class VectorIndexerModel private[ml] (
override val uid: String,
val numFeatures: Int,
val categoryMaps: Map[Int, Map[Double, Int]])
- extends Model[VectorIndexerModel] with VectorIndexerParams {
+ extends Model[VectorIndexerModel] with VectorIndexerParams with MLWritable {
+
+ import VectorIndexerModel._
/** Java-friendly version of [[categoryMaps]] */
def javaCategoryMaps: JMap[JInt, JMap[JDouble, JInt]] = {
@@ -408,4 +416,48 @@ class VectorIndexerModel private[ml] (
val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter = new VectorIndexerModelWriter(this)
+}
+
+@Since("1.6.0")
+object VectorIndexerModel extends MLReadable[VectorIndexerModel] {
+
+ private[VectorIndexerModel]
+ class VectorIndexerModelWriter(instance: VectorIndexerModel) extends MLWriter {
+
+ private case class Data(numFeatures: Int, categoryMaps: Map[Int, Map[Double, Int]])
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.numFeatures, instance.categoryMaps)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class VectorIndexerModelReader extends MLReader[VectorIndexerModel] {
+
+ private val className = classOf[VectorIndexerModel].getName
+
+ override def load(path: String): VectorIndexerModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath)
+ .select("numFeatures", "categoryMaps")
+ .head()
+ val numFeatures = data.getAs[Int](0)
+ val categoryMaps = data.getAs[Map[Int, Map[Double, Int]]](1)
+ val model = new VectorIndexerModel(metadata.uid, numFeatures, categoryMaps)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: MLReader[VectorIndexerModel] = new VectorIndexerModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): VectorIndexerModel = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 708dbeef84..a8d61b6dea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -17,15 +17,17 @@
package org.apache.spark.ml.feature
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.SparkContext
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -90,7 +92,8 @@ private[feature] trait Word2VecBase extends Params
* natural language processing or machine learning process.
*/
@Experimental
-final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase {
+final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] with Word2VecBase
+ with DefaultParamsWritable {
def this() = this(Identifiable.randomUID("w2v"))
@@ -139,6 +142,13 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
override def copy(extra: ParamMap): Word2Vec = defaultCopy(extra)
}
+@Since("1.6.0")
+object Word2Vec extends DefaultParamsReadable[Word2Vec] {
+
+ @Since("1.6.0")
+ override def load(path: String): Word2Vec = super.load(path)
+}
+
/**
* :: Experimental ::
* Model fitted by [[Word2Vec]].
@@ -147,7 +157,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
class Word2VecModel private[ml] (
override val uid: String,
@transient private val wordVectors: feature.Word2VecModel)
- extends Model[Word2VecModel] with Word2VecBase {
+ extends Model[Word2VecModel] with Word2VecBase with MLWritable {
+
+ import Word2VecModel._
/**
* Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
@@ -224,4 +236,49 @@ class Word2VecModel private[ml] (
val copied = new Word2VecModel(uid, wordVectors)
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: MLWriter = new Word2VecModelWriter(this)
+}
+
+@Since("1.6.0")
+object Word2VecModel extends MLReadable[Word2VecModel] {
+
+ private[Word2VecModel]
+ class Word2VecModelWriter(instance: Word2VecModel) extends MLWriter {
+
+ private case class Data(wordIndex: Map[String, Int], wordVectors: Seq[Float])
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.wordVectors.wordIndex, instance.wordVectors.wordVectors.toSeq)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class Word2VecModelReader extends MLReader[Word2VecModel] {
+
+ private val className = classOf[Word2VecModel].getName
+
+ override def load(path: String): Word2VecModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath)
+ .select("wordIndex", "wordVectors")
+ .head()
+ val wordIndex = data.getAs[Map[String, Int]](0)
+ val wordVectors = data.getAs[Seq[Float]](1).toArray
+ val oldModel = new feature.Word2VecModel(wordIndex, wordVectors)
+ val model = new Word2VecModel(metadata.uid, oldModel)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: MLReader[Word2VecModel] = new Word2VecModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): Word2VecModel = super.load(path)
}
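
And a sketch for Word2Vec (again illustrative only; the toy corpus, seed, and path are placeholders): train a model, persist it, reload it, and look at the learned vectors.

    import org.apache.spark.ml.feature.{Word2Vec, Word2VecModel}

    val docDF = sqlContext.createDataFrame(
      Seq("a b c".split(" "), "a b b c a".split(" ")).map(Tuple1.apply)
    ).toDF("text")

    val w2vModel = new Word2Vec()
      .setInputCol("text")
      .setOutputCol("result")
      .setVectorSize(3)
      .setMinCount(0)
      .setSeed(42L)
      .fit(docDF)

    w2vModel.write.overwrite().save("/tmp/w2v-model")  // wordIndex plus flattened vectors go to Parquet
    val restored = Word2VecModel.load("/tmp/w2v-model")
    restored.getVectors.show()                         // "word" / "vector" DataFrame, same content as before saving
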
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 7ab0d89d23..a47f27b0af 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -432,9 +432,9 @@ class Word2Vec extends Serializable with Logging {
* (i * vectorSize, i * vectorSize + vectorSize)
*/
@Since("1.1.0")
-class Word2VecModel private[mllib] (
- private val wordIndex: Map[String, Int],
- private val wordVectors: Array[Float]) extends Serializable with Saveable {
+class Word2VecModel private[spark] (
+ private[spark] val wordIndex: Map[String, Int],
+ private[spark] val wordVectors: Array[Float]) extends Serializable with Saveable {
private val numWords = wordIndex.size
// vectorSize: Dimension of each word's vector.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
index e5a42967bd..7827db2794 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala
@@ -18,13 +18,17 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
-class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
+class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
+
test("Test Chi-Square selector") {
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._
@@ -58,4 +62,20 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(vec1 ~== vec2 absTol 1e-1)
}
}
+
+ test("ChiSqSelector read/write") {
+ val t = new ChiSqSelector()
+ .setFeaturesCol("myFeaturesCol")
+ .setLabelCol("myLabelCol")
+ .setOutputCol("myOutputCol")
+ .setNumTopFeatures(2)
+ testDefaultReadWrite(t)
+ }
+
+ test("ChiSqSelectorModel read/write") {
+ val oldModel = new feature.ChiSqSelectorModel(Array(1, 3))
+ val instance = new ChiSqSelectorModel("myChiSqSelectorModel", oldModel)
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.selectedFeatures === instance.selectedFeatures)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
index 30c500f87a..5a21cd20ce 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala
@@ -19,15 +19,15 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.distributed.RowMatrix
-import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices}
+import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel}
import org.apache.spark.sql.Row
-class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
+class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new PCA)
@@ -65,4 +65,24 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
}
+
+ test("read/write") {
+
+ def checkModelData(model1: PCAModel, model2: PCAModel): Unit = {
+ assert(model1.pc === model2.pc)
+ }
+ val allParams: Map[String, Any] = Map(
+ "k" -> 3,
+ "inputCol" -> "features",
+ "outputCol" -> "pca_features"
+ )
+ val data = Seq(
+ (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))),
+ (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
+ (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
+ )
+ val df = sqlContext.createDataFrame(data).toDF("id", "features")
+ val pca = new PCA().setK(3)
+ testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 8cb0a2cf14..67817fa4ba 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -22,13 +22,14 @@ import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
+class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest with Logging {
import VectorIndexerSuite.FeatureData
@@ -251,6 +252,23 @@ class VectorIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with L
}
}
}
+
+ test("VectorIndexer read/write") {
+ val t = new VectorIndexer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMaxCategories(30)
+ testDefaultReadWrite(t)
+ }
+
+ test("VectorIndexerModel read/write") {
+ val categoryMaps = Map(0 -> Map(0.0 -> 0, 1.0 -> 1), 1 -> Map(0.0 -> 0, 1.0 -> 1,
+ 2.0 -> 2, 3.0 -> 3), 2 -> Map(0.0 -> 0, -1.0 -> 1, 2.0 -> 2))
+ val instance = new VectorIndexerModel("myVectorIndexerModel", 3, categoryMaps)
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.numFeatures === instance.numFeatures)
+ assert(newInstance.categoryMaps === instance.categoryMaps)
+ }
}
private[feature] object VectorIndexerSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index 23dfdaa9f8..a773244cd7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -19,14 +19,14 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel}
-class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
+class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new Word2Vec)
@@ -143,5 +143,31 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+
+ test("Word2Vec read/write") {
+ val t = new Word2Vec()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMaxIter(2)
+ .setMinCount(8)
+ .setNumPartitions(1)
+ .setSeed(42L)
+ .setStepSize(0.01)
+ .setVectorSize(100)
+ testDefaultReadWrite(t)
+ }
+
+ test("Word2VecModel read/write") {
+ val word2VecMap = Map(
+ ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+ ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+ ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+ ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
+ )
+ val oldModel = new OldWord2VecModel(word2VecMap)
+ val instance = new Word2VecModel("myWord2VecModel", oldModel)
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.getVectors.collect() === instance.getVectors.collect())
+ }
}