author    Xiangrui Meng <meng@databricks.com>    2015-11-18 15:47:49 -0800
committer Xiangrui Meng <meng@databricks.com>    2015-11-18 15:47:49 -0800
commit    7e987de1770f4ab3d54bc05db8de0a1ef035941d (patch)
tree      856cbb3cf219827d4022b40675e3b79300ed91e1 /mllib
parent    5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3 (diff)
[SPARK-6787][ML] add read/write to estimators under ml.feature (1)
Add read/write support to the following estimators under spark.ml:

* CountVectorizer
* IDF
* MinMaxScaler
* StandardScaler (a little awkward because we store some params in the spark.mllib model)
* StringIndexer

Added some necessary methods for read/write. Maybe we should add `private[ml] trait DefaultParamsReadable` and `DefaultParamsWritable` to save some boilerplate code, though we would still need to override `load` for Java compatibility.

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9798 from mengxr/SPARK-6787.
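For reference, a minimal sketch of how the new write/load API introduced by this patch is meant to be used, taking CountVectorizer as an example. The SQLContext, save path, column names, and sample data below are illustrative assumptions, not part of the commit:

    import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel}

    // Assumes an existing SQLContext `sqlContext` (as in the test suites below).
    val df = sqlContext.createDataFrame(Seq(
      (0, Array("a", "b", "c")),
      (1, Array("a", "b", "b", "c", "a"))
    )).toDF("id", "text")

    val cv = new CountVectorizer()
      .setInputCol("text")
      .setOutputCol("features")
      .setVocabSize(10)
    val model: CountVectorizerModel = cv.fit(df)

    // New in this patch: the estimator and model mix in Writable and expose `write`.
    model.write.save("/tmp/spark-ml/cv-model")

    // Companion objects mix in Readable and expose `read`/`load` to restore them.
    val restored = CountVectorizerModel.load("/tmp/spark-ml/cv-model")
    assert(restored.vocabulary.sameElements(model.vocabulary))

The same pattern applies to IDF, MinMaxScaler, StandardScaler, and StringIndexer and their models, as the diffs below show.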
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala      | 72
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala                  | 71
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala         | 72
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala       | 78
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala        | 70
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala | 24
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala              | 19
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala     | 25
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala   | 64
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala    | 19
10 files changed, 467 insertions(+), 47 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 49028e4b85..5ff9bfb7d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -16,17 +16,19 @@
*/
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.DataFrame
import org.apache.spark.util.collection.OpenHashMap
/**
@@ -105,7 +107,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
*/
@Experimental
class CountVectorizer(override val uid: String)
- extends Estimator[CountVectorizerModel] with CountVectorizerParams {
+ extends Estimator[CountVectorizerModel] with CountVectorizerParams with Writable {
def this() = this(Identifiable.randomUID("cntVec"))
@@ -169,6 +171,19 @@ class CountVectorizer(override val uid: String)
}
override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object CountVectorizer extends Readable[CountVectorizer] {
+
+ @Since("1.6.0")
+ override def read: Reader[CountVectorizer] = new DefaultParamsReader
+
+ @Since("1.6.0")
+ override def load(path: String): CountVectorizer = super.load(path)
}
/**
@@ -178,7 +193,9 @@ class CountVectorizer(override val uid: String)
*/
@Experimental
class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
- extends Model[CountVectorizerModel] with CountVectorizerParams {
+ extends Model[CountVectorizerModel] with CountVectorizerParams with Writable {
+
+ import CountVectorizerModel._
def this(vocabulary: Array[String]) = {
this(Identifiable.randomUID("cntVecModel"), vocabulary)
@@ -232,4 +249,47 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
copyValues(copied, extra)
}
+
+ @Since("1.6.0")
+ override def write: Writer = new CountVectorizerModelWriter(this)
+}
+
+@Since("1.6.0")
+object CountVectorizerModel extends Readable[CountVectorizerModel] {
+
+ private[CountVectorizerModel]
+ class CountVectorizerModelWriter(instance: CountVectorizerModel) extends Writer {
+
+ private case class Data(vocabulary: Seq[String])
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.vocabulary)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class CountVectorizerModelReader extends Reader[CountVectorizerModel] {
+
+ private val className = "org.apache.spark.ml.feature.CountVectorizerModel"
+
+ override def load(path: String): CountVectorizerModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath)
+ .select("vocabulary")
+ .head()
+ val vocabulary = data.getAs[Seq[String]](0).toArray
+ val model = new CountVectorizerModel(metadata.uid, vocabulary)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: Reader[CountVectorizerModel] = new CountVectorizerModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): CountVectorizerModel = super.load(path)
}
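On disk, the writer above stores the params metadata written by DefaultParamsWriter.saveMetadata alongside a `data/` directory holding a single-row Parquet file with the vocabulary. A minimal sketch of inspecting that data file directly, mirroring CountVectorizerModelReader; the path and `sqlContext` are assumptions for illustration:

    import org.apache.hadoop.fs.Path

    // Assumes `sqlContext` is in scope and a model was saved under `modelPath`.
    val modelPath = "/tmp/spark-ml/cv-model"
    val dataPath = new Path(modelPath, "data").toString

    // Mirrors the reader: one row with a single "vocabulary" column.
    val vocabulary = sqlContext.read.parquet(dataPath)
      .select("vocabulary")
      .head()
      .getAs[Seq[String]](0)
      .toArray

    println(vocabulary.mkString(", "))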
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index 4c36df75d8..53ad34ef12 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -17,11 +17,13 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
@@ -60,7 +62,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Compute the Inverse Document Frequency (IDF) given a collection of documents.
*/
@Experimental
-final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase {
+final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBase with Writable {
def this() = this(Identifiable.randomUID("idf"))
@@ -85,6 +87,19 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
}
override def copy(extra: ParamMap): IDF = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object IDF extends Readable[IDF] {
+
+ @Since("1.6.0")
+ override def read: Reader[IDF] = new DefaultParamsReader
+
+ @Since("1.6.0")
+ override def load(path: String): IDF = super.load(path)
}
/**
@@ -95,7 +110,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
class IDFModel private[ml] (
override val uid: String,
idfModel: feature.IDFModel)
- extends Model[IDFModel] with IDFBase {
+ extends Model[IDFModel] with IDFBase with Writable {
+
+ import IDFModel._
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -117,4 +134,50 @@ class IDFModel private[ml] (
val copied = new IDFModel(uid, idfModel)
copyValues(copied, extra).setParent(parent)
}
+
+ /** Returns the IDF vector. */
+ @Since("1.6.0")
+ def idf: Vector = idfModel.idf
+
+ @Since("1.6.0")
+ override def write: Writer = new IDFModelWriter(this)
+}
+
+@Since("1.6.0")
+object IDFModel extends Readable[IDFModel] {
+
+ private[IDFModel] class IDFModelWriter(instance: IDFModel) extends Writer {
+
+ private case class Data(idf: Vector)
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.idf)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class IDFModelReader extends Reader[IDFModel] {
+
+ private val className = "org.apache.spark.ml.feature.IDFModel"
+
+ override def load(path: String): IDFModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath)
+ .select("idf")
+ .head()
+ val idf = data.getAs[Vector](0)
+ val model = new IDFModel(metadata.uid, new feature.IDFModel(idf))
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: Reader[IDFModel] = new IDFModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): IDFModel = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 1b494ec8b1..24d964fae8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -17,11 +17,14 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params}
-import org.apache.spark.ml.util.Identifiable
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params}
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.sql._
@@ -85,7 +88,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
*/
@Experimental
class MinMaxScaler(override val uid: String)
- extends Estimator[MinMaxScalerModel] with MinMaxScalerParams {
+ extends Estimator[MinMaxScalerModel] with MinMaxScalerParams with Writable {
def this() = this(Identifiable.randomUID("minMaxScal"))
@@ -115,6 +118,19 @@ class MinMaxScaler(override val uid: String)
}
override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object MinMaxScaler extends Readable[MinMaxScaler] {
+
+ @Since("1.6.0")
+ override def read: Reader[MinMaxScaler] = new DefaultParamsReader
+
+ @Since("1.6.0")
+ override def load(path: String): MinMaxScaler = super.load(path)
}
/**
@@ -131,7 +147,9 @@ class MinMaxScalerModel private[ml] (
override val uid: String,
val originalMin: Vector,
val originalMax: Vector)
- extends Model[MinMaxScalerModel] with MinMaxScalerParams {
+ extends Model[MinMaxScalerModel] with MinMaxScalerParams with Writable {
+
+ import MinMaxScalerModel._
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -175,4 +193,46 @@ class MinMaxScalerModel private[ml] (
val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: Writer = new MinMaxScalerModelWriter(this)
+}
+
+@Since("1.6.0")
+object MinMaxScalerModel extends Readable[MinMaxScalerModel] {
+
+ private[MinMaxScalerModel]
+ class MinMaxScalerModelWriter(instance: MinMaxScalerModel) extends Writer {
+
+ private case class Data(originalMin: Vector, originalMax: Vector)
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = new Data(instance.originalMin, instance.originalMax)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class MinMaxScalerModelReader extends Reader[MinMaxScalerModel] {
+
+ private val className = "org.apache.spark.ml.feature.MinMaxScalerModel"
+
+ override def load(path: String): MinMaxScalerModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val Row(originalMin: Vector, originalMax: Vector) = sqlContext.read.parquet(dataPath)
+ .select("originalMin", "originalMax")
+ .head()
+ val model = new MinMaxScalerModel(metadata.uid, originalMin, originalMax)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: Reader[MinMaxScalerModel] = new MinMaxScalerModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): MinMaxScalerModel = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index f6d0b0c0e9..ab04e5418d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -17,11 +17,13 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
@@ -57,7 +59,7 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
*/
@Experimental
class StandardScaler(override val uid: String) extends Estimator[StandardScalerModel]
- with StandardScalerParams {
+ with StandardScalerParams with Writable {
def this() = this(Identifiable.randomUID("stdScal"))
@@ -94,6 +96,19 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
}
override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object StandardScaler extends Readable[StandardScaler] {
+
+ @Since("1.6.0")
+ override def read: Reader[StandardScaler] = new DefaultParamsReader
+
+ @Since("1.6.0")
+ override def load(path: String): StandardScaler = super.load(path)
}
/**
@@ -104,7 +119,9 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
class StandardScalerModel private[ml] (
override val uid: String,
scaler: feature.StandardScalerModel)
- extends Model[StandardScalerModel] with StandardScalerParams {
+ extends Model[StandardScalerModel] with StandardScalerParams with Writable {
+
+ import StandardScalerModel._
/** Standard deviation of the StandardScalerModel */
val std: Vector = scaler.std
@@ -112,6 +129,14 @@ class StandardScalerModel private[ml] (
/** Mean of the StandardScalerModel */
val mean: Vector = scaler.mean
+ /** Whether to scale to unit standard deviation. */
+ @Since("1.6.0")
+ def getWithStd: Boolean = scaler.withStd
+
+ /** Whether to center data with mean. */
+ @Since("1.6.0")
+ def getWithMean: Boolean = scaler.withMean
+
/** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
@@ -138,4 +163,49 @@ class StandardScalerModel private[ml] (
val copied = new StandardScalerModel(uid, scaler)
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: Writer = new StandardScalerModelWriter(this)
+}
+
+@Since("1.6.0")
+object StandardScalerModel extends Readable[StandardScalerModel] {
+
+ private[StandardScalerModel]
+ class StandardScalerModelWriter(instance: StandardScalerModel) extends Writer {
+
+ private case class Data(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean)
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.std, instance.mean, instance.getWithStd, instance.getWithMean)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class StandardScalerModelReader extends Reader[StandardScalerModel] {
+
+ private val className = "org.apache.spark.ml.feature.StandardScalerModel"
+
+ override def load(path: String): StandardScalerModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val Row(std: Vector, mean: Vector, withStd: Boolean, withMean: Boolean) =
+ sqlContext.read.parquet(dataPath)
+ .select("std", "mean", "withStd", "withMean")
+ .head()
+ // This is very likely to change in the future because withStd and withMean should be params.
+ val oldModel = new feature.StandardScalerModel(std, mean, withStd, withMean)
+ val model = new StandardScalerModel(metadata.uid, oldModel)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: Reader[StandardScalerModel] = new StandardScalerModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): StandardScalerModel = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index f782a272d1..f16f6afc00 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,13 +17,14 @@
package org.apache.spark.ml.feature
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.SparkException
-import org.apache.spark.annotation.{Since, Experimental}
-import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
@@ -64,7 +65,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
*/
@Experimental
class StringIndexer(override val uid: String) extends Estimator[StringIndexerModel]
- with StringIndexerBase {
+ with StringIndexerBase with Writable {
def this() = this(Identifiable.randomUID("strIdx"))
@@ -92,6 +93,19 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
}
override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object StringIndexer extends Readable[StringIndexer] {
+
+ @Since("1.6.0")
+ override def read: Reader[StringIndexer] = new DefaultParamsReader
+
+ @Since("1.6.0")
+ override def load(path: String): StringIndexer = super.load(path)
}
/**
@@ -107,7 +121,10 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
@Experimental
class StringIndexerModel (
override val uid: String,
- val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
+ val labels: Array[String])
+ extends Model[StringIndexerModel] with StringIndexerBase with Writable {
+
+ import StringIndexerModel._
def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)
@@ -176,6 +193,49 @@ class StringIndexerModel (
val copied = new StringIndexerModel(uid, labels)
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: StringIndexModelWriter = new StringIndexModelWriter(this)
+}
+
+@Since("1.6.0")
+object StringIndexerModel extends Readable[StringIndexerModel] {
+
+ private[StringIndexerModel]
+ class StringIndexModelWriter(instance: StringIndexerModel) extends Writer {
+
+ private case class Data(labels: Array[String])
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val data = Data(instance.labels)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class StringIndexerModelReader extends Reader[StringIndexerModel] {
+
+ private val className = "org.apache.spark.ml.feature.StringIndexerModel"
+
+ override def load(path: String): StringIndexerModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath)
+ .select("labels")
+ .head()
+ val labels = data.getAs[Seq[String]](0).toArray
+ val model = new StringIndexerModel(metadata.uid, labels)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: Reader[StringIndexerModel] = new StringIndexerModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): StringIndexerModel = super.load(path)
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index e192fa4850..9c99990173 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -18,14 +18,17 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
test("params") {
+ ParamsSuite.checkParams(new CountVectorizer)
ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
}
@@ -164,4 +167,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(features ~== expected absTol 1e-14)
}
}
+
+ test("CountVectorizer read/write") {
+ val t = new CountVectorizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMinDF(0.5)
+ .setMinTF(3.0)
+ .setVocabSize(10)
+ testDefaultReadWrite(t)
+ }
+
+ test("CountVectorizerModel read/write") {
+ val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c"))
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMinTF(3.0)
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.vocabulary === instance.vocabulary)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index 08f80af034..bc958c1585 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
+class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
dataSet.map {
@@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
}
+
+ test("IDF read/write") {
+ val t = new IDF()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMinDocFreq(5)
+ testDefaultReadWrite(t)
+ }
+
+ test("IDFModel read/write") {
+ val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0)))
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.idf === instance.idf)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index c04dda41ee..09183fe65b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}
-class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("MinMaxScaler fit basic case") {
val sqlContext = new SQLContext(sc)
@@ -69,4 +69,25 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("MinMaxScaler read/write") {
+ val t = new MinMaxScaler()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMax(1.0)
+ .setMin(-1.0)
+ testDefaultReadWrite(t)
+ }
+
+ test("MinMaxScalerModel read/write") {
+ val instance = new MinMaxScalerModel(
+ "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0))
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMin(-1.0)
+ .setMax(1.0)
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.originalMin === instance.originalMin)
+ assert(newInstance.originalMax === instance.originalMax)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index 879a3ae875..49a4b2efe0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -19,12 +19,16 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
-class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{
+class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
@transient var data: Array[Vector] = _
@transient var resWithStd: Array[Vector] = _
@@ -56,23 +60,29 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{
)
}
- def assertResult(dataframe: DataFrame): Unit = {
- dataframe.select("standarded_features", "expected").collect().foreach {
+ def assertResult(df: DataFrame): Unit = {
+ df.select("standardized_features", "expected").collect().foreach {
case Row(vector1: Vector, vector2: Vector) =>
assert(vector1 ~== vector2 absTol 1E-5,
"The vector value is not correct after standardization.")
}
}
+ test("params") {
+ ParamsSuite.checkParams(new StandardScaler)
+ val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
+ ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
+ }
+
test("Standardization with default parameter") {
val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
- val standardscaler0 = new StandardScaler()
+ val standardScaler0 = new StandardScaler()
.setInputCol("features")
- .setOutputCol("standarded_features")
+ .setOutputCol("standardized_features")
.fit(df0)
- assertResult(standardscaler0.transform(df0))
+ assertResult(standardScaler0.transform(df0))
}
test("Standardization with setter") {
@@ -80,29 +90,49 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{
val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
- val standardscaler1 = new StandardScaler()
+ val standardScaler1 = new StandardScaler()
.setInputCol("features")
- .setOutputCol("standarded_features")
+ .setOutputCol("standardized_features")
.setWithMean(true)
.setWithStd(true)
.fit(df1)
- val standardscaler2 = new StandardScaler()
+ val standardScaler2 = new StandardScaler()
.setInputCol("features")
- .setOutputCol("standarded_features")
+ .setOutputCol("standardized_features")
.setWithMean(true)
.setWithStd(false)
.fit(df2)
- val standardscaler3 = new StandardScaler()
+ val standardScaler3 = new StandardScaler()
.setInputCol("features")
- .setOutputCol("standarded_features")
+ .setOutputCol("standardized_features")
.setWithMean(false)
.setWithStd(false)
.fit(df3)
- assertResult(standardscaler1.transform(df1))
- assertResult(standardscaler2.transform(df2))
- assertResult(standardscaler3.transform(df3))
+ assertResult(standardScaler1.transform(df1))
+ assertResult(standardScaler2.transform(df2))
+ assertResult(standardScaler3.transform(df3))
+ }
+
+ test("StandardScaler read/write") {
+ val t = new StandardScaler()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setWithStd(false)
+ .setWithMean(true)
+ testDefaultReadWrite(t)
+ }
+
+ test("StandardScalerModel read/write") {
+ val oldModel = new feature.StandardScalerModel(
+ Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
+ val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.std === instance.std)
+ assert(newInstance.mean === instance.mean)
+ assert(newInstance.getWithStd === instance.getWithStd)
+ assert(newInstance.getWithMean === instance.getWithMean)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index be37bfb438..749bfac747 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -118,6 +118,23 @@ class StringIndexerSuite
assert(indexerModel.transform(df).eq(df))
}
+ test("StringIndexer read/write") {
+ val t = new StringIndexer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setHandleInvalid("skip")
+ testDefaultReadWrite(t)
+ }
+
+ test("StringIndexerModel read/write") {
+ val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c"))
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setHandleInvalid("skip")
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.labels === instance.labels)
+ }
+
test("IndexToString params") {
val idxToStr = new IndexToString()
ParamsSuite.checkParams(idxToStr)
@@ -175,7 +192,7 @@ class StringIndexerSuite
assert(outSchema("output").dataType === StringType)
}
- test("read/write") {
+ test("IndexToString read/write") {
val t = new IndexToString()
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")