author     Joseph K. Bradley <joseph@databricks.com>  2015-11-17 12:43:56 -0800
committer  Xiangrui Meng <meng@databricks.com>        2015-11-17 12:43:56 -0800
commit     d98d1cb000c8c4e391d73ae86efd09f15e5d165c (patch)
tree       3fdd8796e394db6beaa62b14f243f5540dd49e7b
parent     d9251496640a77568a1e9ed5045ce2dfba4b437b (diff)
[SPARK-11769][ML] Add save, load to all basic Transformers
This excludes Estimators and ones which include Vector and other non-basic types for Params or data.

This adds:
* Bucketizer
* DCT
* HashingTF
* Interaction
* NGram
* Normalizer
* OneHotEncoder
* PolynomialExpansion
* QuantileDiscretizer
* RFormula
* SQLTransformer
* StopWordsRemover
* StringIndexer
* Tokenizer
* VectorAssembler
* VectorSlicer

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9755 from jkbradley/transformer-io.
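For context, a minimal usage sketch (not part of the patch) of the persistence API these traits add, using Binarizer as the example. It assumes the Writer returned by write exposes a save(path) method and that the companion object's load restores an equivalent instance, which is what the testDefaultReadWrite suites below exercise; the path is illustrative.

    import org.apache.spark.ml.feature.Binarizer

    // Configure a basic transformer.
    val binarizer = new Binarizer()
      .setInputCol("feature")
      .setOutputCol("binarized_feature")
      .setThreshold(0.5)

    // Persist the params (uid, threshold, column names) to a path.
    // Assumes the Writer exposes save(path); the path is illustrative.
    binarizer.write.save("/tmp/spark-ml/binarizer")

    // Reconstruct an equivalent instance via the companion object's load.
    val restored = Binarizer.load("/tmp/spark-ml/binarizer")
    assert(restored.getThreshold == binarizer.getThreshold)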
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala | 29
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala | 27
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala | 35
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala | 22
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala | 25
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala | 12
32 files changed, 453 insertions(+), 84 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index e5c25574d4..e2be6547d8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.BinaryAttribute
import org.apache.spark.ml.param._
@@ -87,10 +87,16 @@ final class Binarizer(override val uid: String)
override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
+ @Since("1.6.0")
override def write: Writer = new DefaultParamsWriter(this)
}
+@Since("1.6.0")
object Binarizer extends Readable[Binarizer] {
+ @Since("1.6.0")
override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer]
+
+ @Since("1.6.0")
+ override def load(path: String): Binarizer = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 6fdf25b015..7095fbd70a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -20,12 +20,12 @@ package org.apache.spark.ml.feature
import java.{util => ju}
import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Model
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
*/
@Experimental
final class Bucketizer(override val uid: String)
- extends Model[Bucketizer] with HasInputCol with HasOutputCol {
+ extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable {
def this() = this(Identifiable.randomUID("bucketizer"))
@@ -93,11 +93,15 @@ final class Bucketizer(override val uid: String)
override def copy(extra: ParamMap): Bucketizer = {
defaultCopy[Bucketizer](extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
}
-private[feature] object Bucketizer {
+object Bucketizer extends Readable[Bucketizer] {
+
/** We require splits to be of length >= 3 and to be in strictly increasing order. */
- def checkSplits(splits: Array[Double]): Boolean = {
+ private[feature] def checkSplits(splits: Array[Double]): Boolean = {
if (splits.length < 3) {
false
} else {
@@ -115,7 +119,7 @@ private[feature] object Bucketizer {
* Binary searching in several buckets to place each data point.
* @throws SparkException if a feature is < splits.head or > splits.last
*/
- def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+ private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
if (feature == splits.last) {
splits.length - 2
} else {
@@ -134,4 +138,10 @@ private[feature] object Bucketizer {
}
}
}
+
+ @Since("1.6.0")
+ override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer]
+
+ @Since("1.6.0")
+ override def load(path: String): Bucketizer = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
index 228347635c..6ea5a61617 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala
@@ -19,10 +19,10 @@ package org.apache.spark.ml.feature
import edu.emory.mathcs.jtransforms.dct._
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.BooleanParam
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.types.DataType
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType
*/
@Experimental
class DCT(override val uid: String)
- extends UnaryTransformer[Vector, Vector, DCT] {
+ extends UnaryTransformer[Vector, Vector, DCT] with Writable {
def this() = this(Identifiable.randomUID("dct"))
@@ -69,4 +69,17 @@ class DCT(override val uid: String)
}
override protected def outputDataType: DataType = new VectorUDT
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object DCT extends Readable[DCT] {
+
+ @Since("1.6.0")
+ override def read: Reader[DCT] = new DefaultParamsReader[DCT]
+
+ @Since("1.6.0")
+ override def load(path: String): DCT = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 319d23e46c..6d2ea675f5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -17,12 +17,12 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{ArrayType, StructType}
* Maps a sequence of terms to their term frequencies using the hashing trick.
*/
@Experimental
-class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol {
+class HashingTF(override val uid: String)
+ extends Transformer with HasInputCol with HasOutputCol with Writable {
def this() = this(Identifiable.randomUID("hashingTF"))
@@ -76,4 +77,17 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w
}
override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object HashingTF extends Readable[HashingTF] {
+
+ @Since("1.6.0")
+ override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF]
+
+ @Since("1.6.0")
+ override def load(path: String): HashingTF = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 37f7862476..9df6b311cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -20,11 +20,11 @@ package org.apache.spark.ml.feature
import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.ml.Transformer
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.{DataFrame, Row}
@@ -42,24 +42,30 @@ import org.apache.spark.sql.types._
* `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal
* with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`.
*/
+@Since("1.6.0")
@Experimental
-class Interaction(override val uid: String) extends Transformer
- with HasInputCols with HasOutputCol {
+class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
+ with HasInputCols with HasOutputCol with Writable {
+ @Since("1.6.0")
def this() = this(Identifiable.randomUID("interaction"))
/** @group setParam */
+ @Since("1.6.0")
def setInputCols(values: Array[String]): this.type = set(inputCols, values)
/** @group setParam */
+ @Since("1.6.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
// optimistic schema; does not contain any ML attributes
+ @Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
}
+ @Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
validateParams()
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
@@ -208,14 +214,29 @@ class Interaction(override val uid: String) extends Transformer
}
}
+ @Since("1.6.0")
override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
+ @Since("1.6.0")
override def validateParams(): Unit = {
require(get(inputCols).isDefined, "Input cols must be defined first.")
require(get(outputCol).isDefined, "Output col must be defined first.")
require($(inputCols).length > 0, "Input cols must have non-zero length.")
require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
}
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object Interaction extends Readable[Interaction] {
+
+ @Since("1.6.0")
+ override def read: Reader[Interaction] = new DefaultParamsReader[Interaction]
+
+ @Since("1.6.0")
+ override def load(path: String): Interaction = read.load(path)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
index 8de10eb51f..4a17acd951 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
@@ -17,10 +17,10 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
/**
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
*/
@Experimental
class NGram(override val uid: String)
- extends UnaryTransformer[Seq[String], Seq[String], NGram] {
+ extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable {
def this() = this(Identifiable.randomUID("ngram"))
@@ -66,4 +66,17 @@ class NGram(override val uid: String)
}
override protected def outputDataType: DataType = new ArrayType(StringType, false)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object NGram extends Readable[NGram] {
+
+ @Since("1.6.0")
+ override def read: Reader[NGram] = new DefaultParamsReader[NGram]
+
+ @Since("1.6.0")
+ override def load(path: String): NGram = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
index 8282e5ffa1..9df6a091d5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala
@@ -17,10 +17,10 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{DoubleParam, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.types.DataType
@@ -30,7 +30,8 @@ import org.apache.spark.sql.types.DataType
* Normalize a vector to have unit norm using the given p-norm.
*/
@Experimental
-class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] {
+class Normalizer(override val uid: String)
+ extends UnaryTransformer[Vector, Vector, Normalizer] with Writable {
def this() = this(Identifiable.randomUID("normalizer"))
@@ -55,4 +56,17 @@ class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vect
}
override protected def outputDataType: DataType = new VectorUDT()
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object Normalizer extends Readable[Normalizer] {
+
+ @Since("1.6.0")
+ override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer]
+
+ @Since("1.6.0")
+ override def load(path: String): Normalizer = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index 9c60d4084e..4e2adfaafa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -17,12 +17,12 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
@@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType}
*/
@Experimental
class OneHotEncoder(override val uid: String) extends Transformer
- with HasInputCol with HasOutputCol {
+ with HasInputCol with HasOutputCol with Writable {
def this() = this(Identifiable.randomUID("oneHot"))
@@ -165,4 +165,17 @@ class OneHotEncoder(override val uid: String) extends Transformer
}
override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object OneHotEncoder extends Readable[OneHotEncoder] {
+
+ @Since("1.6.0")
+ override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder]
+
+ @Since("1.6.0")
+ override def load(path: String): OneHotEncoder = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
index d85e468562..4941539832 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala
@@ -19,10 +19,10 @@ package org.apache.spark.ml.feature
import scala.collection.mutable
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType
*/
@Experimental
class PolynomialExpansion(override val uid: String)
- extends UnaryTransformer[Vector, Vector, PolynomialExpansion] {
+ extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable {
def this() = this(Identifiable.randomUID("poly"))
@@ -63,6 +63,9 @@ class PolynomialExpansion(override val uid: String)
override protected def outputDataType: DataType = new VectorUDT()
override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
}
/**
@@ -77,7 +80,8 @@ class PolynomialExpansion(override val uid: String)
* To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the
* current index and increment it properly for sparse input.
*/
-private[feature] object PolynomialExpansion {
+@Since("1.6.0")
+object PolynomialExpansion extends Readable[PolynomialExpansion] {
private def choose(n: Int, k: Int): Int = {
Range(n, n - k, -1).product / Range(k, 1, -1).product
@@ -169,11 +173,17 @@ private[feature] object PolynomialExpansion {
new SparseVector(polySize - 1, polyIndices.result(), polyValues.result())
}
- def expand(v: Vector, degree: Int): Vector = {
+ private[feature] def expand(v: Vector, degree: Int): Vector = {
v match {
case dv: DenseVector => expand(dv, degree)
case sv: SparseVector => expand(sv, degree)
case _ => throw new IllegalArgumentException
}
}
+
+ @Since("1.6.0")
+ override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion]
+
+ @Since("1.6.0")
+ override def load(path: String): PolynomialExpansion = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 46b836da9c..2da5c966d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import scala.collection.mutable
import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
@@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w
*/
@Experimental
final class QuantileDiscretizer(override val uid: String)
- extends Estimator[Bucketizer] with QuantileDiscretizerBase {
+ extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable {
def this() = this(Identifiable.randomUID("quantileDiscretizer"))
@@ -93,13 +93,17 @@ final class QuantileDiscretizer(override val uid: String)
}
override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
}
-private[feature] object QuantileDiscretizer extends Logging {
+@Since("1.6.0")
+object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging {
/**
* Sampling from the given dataset to collect quantile statistics.
*/
- def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = {
+ private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = {
val totalSamples = dataset.count()
require(totalSamples > 0,
"QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
@@ -111,6 +115,7 @@ private[feature] object QuantileDiscretizer extends Logging {
/**
* Compute split points with respect to the sample distribution.
*/
+ private[feature]
def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
m + ((x, m.getOrElse(x, 0) + 1))
@@ -150,7 +155,7 @@ private[feature] object QuantileDiscretizer extends Logging {
* Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
* needed, and adding a default split value of 0 if no good candidates are found.
*/
- def getSplits(candidates: Array[Double]): Array[Double] = {
+ private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
val effectiveValues = if (candidates.size != 0) {
if (candidates.head == Double.NegativeInfinity
&& candidates.last == Double.PositiveInfinity) {
@@ -172,5 +177,10 @@ private[feature] object QuantileDiscretizer extends Logging {
Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
}
}
-}
+ @Since("1.6.0")
+ override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer]
+
+ @Since("1.6.0")
+ override def load(path: String): QuantileDiscretizer = read.load(path)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 95e4305638..c115064ff3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -18,10 +18,10 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkContext
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.param.{ParamMap, Param}
import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.types.StructType
@@ -32,24 +32,30 @@ import org.apache.spark.sql.types.StructType
* where '__THIS__' represents the underlying table of the input dataset.
*/
@Experimental
-class SQLTransformer (override val uid: String) extends Transformer {
+@Since("1.6.0")
+class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable {
+ @Since("1.6.0")
def this() = this(Identifiable.randomUID("sql"))
/**
* SQL statement parameter. The statement is provided in string form.
* @group param
*/
+ @Since("1.6.0")
final val statement: Param[String] = new Param[String](this, "statement", "SQL statement")
/** @group setParam */
+ @Since("1.6.0")
def setStatement(value: String): this.type = set(statement, value)
/** @group getParam */
+ @Since("1.6.0")
def getStatement: String = $(statement)
private val tableIdentifier: String = "__THIS__"
+ @Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
val tableName = Identifiable.randomUID(uid)
dataset.registerTempTable(tableName)
@@ -58,6 +64,7 @@ class SQLTransformer (override val uid: String) extends Transformer {
outputDF
}
+ @Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
@@ -68,5 +75,19 @@ class SQLTransformer (override val uid: String) extends Transformer {
outputSchema
}
+ @Since("1.6.0")
override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object SQLTransformer extends Readable[SQLTransformer] {
+
+ @Since("1.6.0")
+ override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer]
+
+ @Since("1.6.0")
+ override def load(path: String): SQLTransformer = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 2a79582625..f1146988dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -17,11 +17,11 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
@@ -86,7 +86,7 @@ private[spark] object StopWords {
*/
@Experimental
class StopWordsRemover(override val uid: String)
- extends Transformer with HasInputCol with HasOutputCol {
+ extends Transformer with HasInputCol with HasOutputCol with Writable {
def this() = this(Identifiable.randomUID("stopWords"))
@@ -154,4 +154,17 @@ class StopWordsRemover(override val uid: String)
}
override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object StopWordsRemover extends Readable[StopWordsRemover] {
+
+ @Since("1.6.0")
+ override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover]
+
+ @Since("1.6.0")
+ override def load(path: String): StopWordsRemover = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 486274cd75..f782a272d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -18,13 +18,13 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
@@ -188,9 +188,8 @@ class StringIndexerModel (
* @see [[StringIndexer]] for converting strings into indices
*/
@Experimental
-class IndexToString private[ml] (
- override val uid: String) extends Transformer
- with HasInputCol with HasOutputCol {
+class IndexToString private[ml] (override val uid: String)
+ extends Transformer with HasInputCol with HasOutputCol with Writable {
def this() =
this(Identifiable.randomUID("idxToStr"))
@@ -257,4 +256,17 @@ class IndexToString private[ml] (
override def copy(extra: ParamMap): IndexToString = {
defaultCopy(extra)
}
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object IndexToString extends Readable[IndexToString] {
+
+ @Since("1.6.0")
+ override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString]
+
+ @Since("1.6.0")
+ override def load(path: String): IndexToString = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 1b82b40caa..0e4445d1e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -17,10 +17,10 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
/**
@@ -30,7 +30,8 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
* @see [[RegexTokenizer]]
*/
@Experimental
-class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] {
+class Tokenizer(override val uid: String)
+ extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable {
def this() = this(Identifiable.randomUID("tok"))
@@ -45,6 +46,19 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
override protected def outputDataType: DataType = new ArrayType(StringType, true)
override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object Tokenizer extends Readable[Tokenizer] {
+
+ @Since("1.6.0")
+ override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer]
+
+ @Since("1.6.0")
+ override def load(path: String): Tokenizer = read.load(path)
}
/**
@@ -56,7 +70,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
*/
@Experimental
class RegexTokenizer(override val uid: String)
- extends UnaryTransformer[String, Seq[String], RegexTokenizer] {
+ extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable {
def this() = this(Identifiable.randomUID("regexTok"))
@@ -131,4 +145,17 @@ class RegexTokenizer(override val uid: String)
override protected def outputDataType: DataType = new ArrayType(StringType, true)
override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
+}
+
+@Since("1.6.0")
+object RegexTokenizer extends Readable[RegexTokenizer] {
+
+ @Since("1.6.0")
+ override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer]
+
+ @Since("1.6.0")
+ override def load(path: String): RegexTokenizer = read.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 086917fa68..7e54205292 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -20,12 +20,12 @@ package org.apache.spark.ml.feature
import scala.collection.mutable.ArrayBuilder
import org.apache.spark.SparkException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types._
*/
@Experimental
class VectorAssembler(override val uid: String)
- extends Transformer with HasInputCols with HasOutputCol {
+ extends Transformer with HasInputCols with HasOutputCol with Writable {
def this() = this(Identifiable.randomUID("vecAssembler"))
@@ -120,9 +120,19 @@ class VectorAssembler(override val uid: String)
}
override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
}
-private object VectorAssembler {
+@Since("1.6.0")
+object VectorAssembler extends Readable[VectorAssembler] {
+
+ @Since("1.6.0")
+ override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler]
+
+ @Since("1.6.0")
+ override def load(path: String): VectorAssembler = read.load(path)
private[feature] def assemble(vv: Any*): Vector = {
val indices = ArrayBuilder.make[Int]
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index fb3387d4aa..911582b55b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -17,12 +17,12 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam}
-import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
@@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType
*/
@Experimental
final class VectorSlicer(override val uid: String)
- extends Transformer with HasInputCol with HasOutputCol {
+ extends Transformer with HasInputCol with HasOutputCol with Writable {
def this() = this(Identifiable.randomUID("vectorSlicer"))
@@ -151,12 +151,16 @@ final class VectorSlicer(override val uid: String)
}
override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
}
-private[feature] object VectorSlicer {
+@Since("1.6.0")
+object VectorSlicer extends Readable[VectorSlicer] {
/** Return true if given feature indices are valid */
- def validIndices(indices: Array[Int]): Boolean = {
+ private[feature] def validIndices(indices: Array[Int]): Boolean = {
if (indices.isEmpty) {
true
} else {
@@ -165,7 +169,13 @@ private[feature] object VectorSlicer {
}
/** Return true if given feature names are valid */
- def validNames(names: Array[String]): Boolean = {
+ private[feature] def validNames(names: Array[String]): Boolean = {
names.forall(_.nonEmpty) && names.length == names.distinct.length
}
+
+ @Since("1.6.0")
+ override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer]
+
+ @Since("1.6.0")
+ override def load(path: String): VectorSlicer = read.load(path)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 9dfa1439cc..6d2d8fe714 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -69,10 +69,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
test("read/write") {
- val binarizer = new Binarizer()
- .setInputCol("feature")
- .setOutputCol("binarized_feature")
+ val t = new Binarizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
.setThreshold(0.1)
- testDefaultReadWrite(binarizer)
+ testDefaultReadWrite(t)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 0eba34fda6..9ea7d43176 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -21,13 +21,13 @@ import scala.util.Random
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new Bucketizer)
@@ -112,6 +112,14 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
assert(bsResult ~== lsResult absTol 1e-5)
}
+
+ test("read/write") {
+ val t = new Bucketizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setSplits(Array(0.1, 0.8, 0.9))
+ testDefaultReadWrite(t)
+ }
}
private object BucketizerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
index 37ed2367c3..0f2aafebaf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
@@ -22,6 +22,7 @@ import scala.beans.BeanInfo
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -29,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class DCTTestData(vec: Vector, wantedVec: Vector)
-class DCTSuite extends SparkFunSuite with MLlibTestSparkContext {
+class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("forward transform of discrete cosine matches jTransforms result") {
val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
@@ -45,6 +46,14 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext {
testDCT(data, inverse)
}
+ test("read/write") {
+ val t = new DCT()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setInverse(true)
+ testDefaultReadWrite(t)
+ }
+
private def testDCT(data: Vector, inverse: Boolean): Unit = {
val expectedResultBuffer = data.toArray.clone()
if (inverse) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 4157b84b29..0dcd0f4946 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
+class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new HashingTF)
@@ -50,4 +51,12 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
assert(features ~== expected absTol 1e-14)
}
+
+ test("read/write") {
+ val t = new HashingTF()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setNumFeatures(10)
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
index 2beb62ca08..932d331b47 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import scala.collection.mutable.ArrayBuilder
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
@@ -26,7 +27,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.functions.col
-class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new Interaction())
}
@@ -162,4 +163,11 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext {
new NumericAttribute(Some("a_2:b_1:c"), Some(9))))
assert(attrs === expectedAttrs)
}
+
+ test("read/write") {
+ val t = new Interaction()
+ .setInputCols(Array("myInputCol", "myInputCol2"))
+ .setOutputCol("myOutputCol")
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
index ab97e3dbc6..58fda29aa1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -20,13 +20,14 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
-class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
+class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import org.apache.spark.ml.feature.NGramSuite._
test("default behavior yields bigram features") {
@@ -79,6 +80,14 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
)))
testNGram(nGram, dataset)
}
+
+ test("read/write") {
+ val t = new NGram()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setN(3)
+ testDefaultReadWrite(t)
+ }
}
object NGramSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index 9f03470b7f..de3d438ce8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -18,13 +18,14 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var data: Array[Vector] = _
@transient var dataFrame: DataFrame = _
@@ -104,6 +105,14 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assertValues(result, l1Normalized)
}
+
+ test("read/write") {
+ val t = new Normalizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setP(3.0)
+ testDefaultReadWrite(t)
+ }
}
private object NormalizerSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 321eeb8439..76d12050f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -20,12 +20,14 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
-class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
+class OneHotEncoderSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
@@ -101,4 +103,12 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
}
+
+ test("read/write") {
+ val t = new OneHotEncoder()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setDropLast(false)
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index 29eebd8960..70892dc571 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -21,12 +21,14 @@ import org.apache.spark.ml.param.ParamsSuite
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PolynomialExpansionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new PolynomialExpansion)
@@ -98,5 +100,13 @@ class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext
throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
}
}
+
+ test("read/write") {
+ val t = new PolynomialExpansion()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setDegree(3)
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index b2bdd8935f..3a4f6d235a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -18,11 +18,14 @@
package org.apache.spark.ml.feature
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkContext, SparkFunSuite}
-class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class QuantileDiscretizerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
test("Test quantile discretizer") {
@@ -67,6 +70,14 @@ class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext
assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
}
}
+
+ test("read/write") {
+ val t = new QuantileDiscretizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setNumBuckets(6)
+ testDefaultReadWrite(t)
+ }
}
private object QuantileDiscretizerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index d19052881a..553e0b8702 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -19,9 +19,11 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class SQLTransformerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new SQLTransformer())
@@ -41,4 +43,10 @@ class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(resultSchema == expected.schema)
assert(result.collect().toSeq == expected.collect().toSeq)
}
+
+ test("read/write") {
+ val t = new SQLTransformer()
+ .setStatement("select * from __THIS__")
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index e0d433f566..fb217e0c1d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -32,7 +33,9 @@ object StopWordsRemoverSuite extends SparkFunSuite {
}
}
-class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
+class StopWordsRemoverSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
import StopWordsRemoverSuite._
test("StopWordsRemover default") {
@@ -77,4 +80,13 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
testStopWordsRemover(remover, dataSet)
}
+
+ test("read/write") {
+ val t = new StopWordsRemover()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setStopWords(Array("the", "a"))
+ .setCaseSensitive(true)
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index ddcdb5f421..be37bfb438 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -21,12 +21,13 @@ import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleTy
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
-class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class StringIndexerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new StringIndexer)
@@ -173,4 +174,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
val outSchema = idxToStr.transformSchema(inSchema)
assert(outSchema("output").dataType === StringType)
}
+
+ test("read/write") {
+ val t = new IndexToString()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setLabels(Array("a", "b", "c"))
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index a02992a240..36e8e5d868 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -21,20 +21,30 @@ import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
-class TokenizerSuite extends SparkFunSuite {
+class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new Tokenizer)
}
+
+ test("read/write") {
+ val t = new Tokenizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ testDefaultReadWrite(t)
+ }
}
-class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RegexTokenizerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
import org.apache.spark.ml.feature.RegexTokenizerSuite._
test("params") {
@@ -81,6 +91,17 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
))
testRegexTokenizer(tokenizer, dataset)
}
+
+ test("read/write") {
+ val t = new RegexTokenizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMinTokenLength(2)
+ .setGaps(false)
+ .setPattern("hi")
+ .setToLowercase(false)
+ testDefaultReadWrite(t)
+ }
}
object RegexTokenizerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index bb4d5b983e..fb21ab6b9b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
@@ -25,7 +26,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
-class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class VectorAssemblerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new VectorAssembler)
@@ -101,4 +103,11 @@ class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
}
+
+ test("read/write") {
+ val t = new VectorAssembler()
+ .setInputCols(Array("myInputCol", "myInputCol2"))
+ .setOutputCol("myOutputCol")
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index a6c2fba836..74706a23e0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
val slicer = new VectorSlicer
@@ -106,4 +107,13 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
validateResults(vectorSlicer.transform(df))
}
+
+ test("read/write") {
+ val t = new VectorSlicer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setIndices(Array(1, 3))
+ .setNames(Array("a", "d"))
+ testDefaultReadWrite(t)
+ }
}