diff options
Diffstat (limited to 'mllib')
32 files changed, 453 insertions, 84 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index e5c25574d4..e2be6547d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ @@ -87,10 +87,16 @@ final class Binarizer(override val uid: String) override def copy(extra: ParamMap): Binarizer = defaultCopy(extra) + @Since("1.6.0") override def write: Writer = new DefaultParamsWriter(this) } +@Since("1.6.0") object Binarizer extends Readable[Binarizer] { + @Since("1.6.0") override def read: Reader[Binarizer] = new DefaultParamsReader[Binarizer] + + @Since("1.6.0") + override def load(path: String): Binarizer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 6fdf25b015..7095fbd70a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} */ @Experimental final class Bucketizer(override val uid: String) - extends Model[Bucketizer] with HasInputCol with HasOutputCol { + extends Model[Bucketizer] with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("bucketizer")) @@ -93,11 +93,15 @@ final class Bucketizer(override val uid: String) override def copy(extra: ParamMap): Bucketizer = { defaultCopy[Bucketizer](extra).setParent(parent) } + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } -private[feature] object Bucketizer { +object Bucketizer extends Readable[Bucketizer] { + /** We require splits to be of length >= 3 and to be in strictly increasing order. */ - def checkSplits(splits: Array[Double]): Boolean = { + private[feature] def checkSplits(splits: Array[Double]): Boolean = { if (splits.length < 3) { false } else { @@ -115,7 +119,7 @@ private[feature] object Bucketizer { * Binary searching in several buckets to place each data point. * @throws SparkException if a feature is < splits.head or > splits.last */ - def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { + private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { if (feature == splits.last) { splits.length - 2 } else { @@ -134,4 +138,10 @@ private[feature] object Bucketizer { } } } + + @Since("1.6.0") + override def read: Reader[Bucketizer] = new DefaultParamsReader[Bucketizer] + + @Since("1.6.0") + override def load(path: String): Bucketizer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 228347635c..6ea5a61617 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.feature import edu.emory.mathcs.jtransforms.dct._ -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.BooleanParam -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.types.DataType @@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class DCT(override val uid: String) - extends UnaryTransformer[Vector, Vector, DCT] { + extends UnaryTransformer[Vector, Vector, DCT] with Writable { def this() = this(Identifiable.randomUID("dct")) @@ -69,4 +69,17 @@ class DCT(override val uid: String) } override protected def outputDataType: DataType = new VectorUDT + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object DCT extends Readable[DCT] { + + @Since("1.6.0") + override def read: Reader[DCT] = new DefaultParamsReader[DCT] + + @Since("1.6.0") + override def load(path: String): DCT = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 319d23e46c..6d2ea675f5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, udf} @@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{ArrayType, StructType} * Maps a sequence of terms to their term frequencies using the hashing trick. */ @Experimental -class HashingTF(override val uid: String) extends Transformer with HasInputCol with HasOutputCol { +class HashingTF(override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("hashingTF")) @@ -76,4 +77,17 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w } override def copy(extra: ParamMap): HashingTF = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object HashingTF extends Readable[HashingTF] { + + @Since("1.6.0") + override def read: Reader[HashingTF] = new DefaultParamsReader[HashingTF] + + @Since("1.6.0") + override def load(path: String): HashingTF = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 37f7862476..9df6b311cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{DataFrame, Row} @@ -42,24 +42,30 @@ import org.apache.spark.sql.types._ * `Vector(6, 8)` if all input features were numeric. If the first feature was instead nominal * with four categories, the output would then be `Vector(0, 0, 0, 0, 3, 4, 0, 0)`. */ +@Since("1.6.0") @Experimental -class Interaction(override val uid: String) extends Transformer - with HasInputCols with HasOutputCol { +class Interaction @Since("1.6.0") (override val uid: String) extends Transformer + with HasInputCols with HasOutputCol with Writable { + @Since("1.6.0") def this() = this(Identifiable.randomUID("interaction")) /** @group setParam */ + @Since("1.6.0") def setInputCols(values: Array[String]): this.type = set(inputCols, values) /** @group setParam */ + @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) // optimistic schema; does not contain any ML attributes + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { validateParams() StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) } + @Since("1.6.0") override def transform(dataset: DataFrame): DataFrame = { validateParams() val inputFeatures = $(inputCols).map(c => dataset.schema(c)) @@ -208,14 +214,29 @@ class Interaction(override val uid: String) extends Transformer } } + @Since("1.6.0") override def copy(extra: ParamMap): Interaction = defaultCopy(extra) + @Since("1.6.0") override def validateParams(): Unit = { require(get(inputCols).isDefined, "Input cols must be defined first.") require(get(outputCol).isDefined, "Output col must be defined first.") require($(inputCols).length > 0, "Input cols must have non-zero length.") require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.") } + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object Interaction extends Readable[Interaction] { + + @Since("1.6.0") + override def read: Reader[Interaction] = new DefaultParamsReader[Interaction] + + @Since("1.6.0") + override def load(path: String): Interaction = read.load(path) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 8de10eb51f..4a17acd951 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} */ @Experimental class NGram(override val uid: String) - extends UnaryTransformer[Seq[String], Seq[String], NGram] { + extends UnaryTransformer[Seq[String], Seq[String], NGram] with Writable { def this() = this(Identifiable.randomUID("ngram")) @@ -66,4 +66,17 @@ class NGram(override val uid: String) } override protected def outputDataType: DataType = new ArrayType(StringType, false) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object NGram extends Readable[NGram] { + + @Since("1.6.0") + override def read: Reader[NGram] = new DefaultParamsReader[NGram] + + @Since("1.6.0") + override def load(path: String): NGram = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index 8282e5ffa1..9df6a091d5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.types.DataType @@ -30,7 +30,8 @@ import org.apache.spark.sql.types.DataType * Normalize a vector to have unit norm using the given p-norm. */ @Experimental -class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vector, Normalizer] { +class Normalizer(override val uid: String) + extends UnaryTransformer[Vector, Vector, Normalizer] with Writable { def this() = this(Identifiable.randomUID("normalizer")) @@ -55,4 +56,17 @@ class Normalizer(override val uid: String) extends UnaryTransformer[Vector, Vect } override protected def outputDataType: DataType = new VectorUDT() + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object Normalizer extends Readable[Normalizer] { + + @Since("1.6.0") + override def read: Reader[Normalizer] = new DefaultParamsReader[Normalizer] + + @Since("1.6.0") + override def load(path: String): Normalizer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 9c60d4084e..4e2adfaafa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, udf} @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.{DoubleType, StructType} */ @Experimental class OneHotEncoder(override val uid: String) extends Transformer - with HasInputCol with HasOutputCol { + with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("oneHot")) @@ -165,4 +165,17 @@ class OneHotEncoder(override val uid: String) extends Transformer } override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object OneHotEncoder extends Readable[OneHotEncoder] { + + @Since("1.6.0") + override def read: Reader[OneHotEncoder] = new DefaultParamsReader[OneHotEncoder] + + @Since("1.6.0") + override def load(path: String): OneHotEncoder = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index d85e468562..4941539832 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType @@ -36,7 +36,7 @@ import org.apache.spark.sql.types.DataType */ @Experimental class PolynomialExpansion(override val uid: String) - extends UnaryTransformer[Vector, Vector, PolynomialExpansion] { + extends UnaryTransformer[Vector, Vector, PolynomialExpansion] with Writable { def this() = this(Identifiable.randomUID("poly")) @@ -63,6 +63,9 @@ class PolynomialExpansion(override val uid: String) override protected def outputDataType: DataType = new VectorUDT() override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } /** @@ -77,7 +80,8 @@ class PolynomialExpansion(override val uid: String) * To handle sparsity, if c is zero, we can skip all monomials that contain it. We remember the * current index and increment it properly for sparse input. */ -private[feature] object PolynomialExpansion { +@Since("1.6.0") +object PolynomialExpansion extends Readable[PolynomialExpansion] { private def choose(n: Int, k: Int): Int = { Range(n, n - k, -1).product / Range(k, 1, -1).product @@ -169,11 +173,17 @@ private[feature] object PolynomialExpansion { new SparseVector(polySize - 1, polyIndices.result(), polyValues.result()) } - def expand(v: Vector, degree: Int): Vector = { + private[feature] def expand(v: Vector, degree: Int): Vector = { v match { case dv: DenseVector => expand(dv, degree) case sv: SparseVector => expand(sv, degree) case _ => throw new IllegalArgumentException } } + + @Since("1.6.0") + override def read: Reader[PolynomialExpansion] = new DefaultParamsReader[PolynomialExpansion] + + @Since("1.6.0") + override def load(path: String): PolynomialExpansion = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 46b836da9c..2da5c966d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import scala.collection.mutable import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} @@ -60,7 +60,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol w */ @Experimental final class QuantileDiscretizer(override val uid: String) - extends Estimator[Bucketizer] with QuantileDiscretizerBase { + extends Estimator[Bucketizer] with QuantileDiscretizerBase with Writable { def this() = this(Identifiable.randomUID("quantileDiscretizer")) @@ -93,13 +93,17 @@ final class QuantileDiscretizer(override val uid: String) } override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } -private[feature] object QuantileDiscretizer extends Logging { +@Since("1.6.0") +object QuantileDiscretizer extends Readable[QuantileDiscretizer] with Logging { /** * Sampling from the given dataset to collect quantile statistics. */ - def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = { + private[feature] def getSampledInput(dataset: DataFrame, numBins: Int): Array[Row] = { val totalSamples = dataset.count() require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") @@ -111,6 +115,7 @@ private[feature] object QuantileDiscretizer extends Logging { /** * Compute split points with respect to the sample distribution. */ + private[feature] def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = { val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) => m + ((x, m.getOrElse(x, 0) + 1)) @@ -150,7 +155,7 @@ private[feature] object QuantileDiscretizer extends Logging { * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as * needed, and adding a default split value of 0 if no good candidates are found. */ - def getSplits(candidates: Array[Double]): Array[Double] = { + private[feature] def getSplits(candidates: Array[Double]): Array[Double] = { val effectiveValues = if (candidates.size != 0) { if (candidates.head == Double.NegativeInfinity && candidates.last == Double.PositiveInfinity) { @@ -172,5 +177,10 @@ private[feature] object QuantileDiscretizer extends Logging { Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity) } } -} + @Since("1.6.0") + override def read: Reader[QuantileDiscretizer] = new DefaultParamsReader[QuantileDiscretizer] + + @Since("1.6.0") + override def load(path: String): QuantileDiscretizer = read.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 95e4305638..c115064ff3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -18,10 +18,10 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkContext -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.param.{ParamMap, Param} import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.{SQLContext, DataFrame, Row} import org.apache.spark.sql.types.StructType @@ -32,24 +32,30 @@ import org.apache.spark.sql.types.StructType * where '__THIS__' represents the underlying table of the input dataset. */ @Experimental -class SQLTransformer (override val uid: String) extends Transformer { +@Since("1.6.0") +class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transformer with Writable { + @Since("1.6.0") def this() = this(Identifiable.randomUID("sql")) /** * SQL statement parameter. The statement is provided in string form. * @group param */ + @Since("1.6.0") final val statement: Param[String] = new Param[String](this, "statement", "SQL statement") /** @group setParam */ + @Since("1.6.0") def setStatement(value: String): this.type = set(statement, value) /** @group getParam */ + @Since("1.6.0") def getStatement: String = $(statement) private val tableIdentifier: String = "__THIS__" + @Since("1.6.0") override def transform(dataset: DataFrame): DataFrame = { val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) @@ -58,6 +64,7 @@ class SQLTransformer (override val uid: String) extends Transformer { outputDF } + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) @@ -68,5 +75,19 @@ class SQLTransformer (override val uid: String) extends Transformer { outputSchema } + @Since("1.6.0") override def copy(extra: ParamMap): SQLTransformer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object SQLTransformer extends Readable[SQLTransformer] { + + @Since("1.6.0") + override def read: Reader[SQLTransformer] = new DefaultParamsReader[SQLTransformer] + + @Since("1.6.0") + override def load(path: String): SQLTransformer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 2a79582625..f1146988dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} @@ -86,7 +86,7 @@ private[spark] object StopWords { */ @Experimental class StopWordsRemover(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("stopWords")) @@ -154,4 +154,17 @@ class StopWordsRemover(override val uid: String) } override def copy(extra: ParamMap): StopWordsRemover = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object StopWordsRemover extends Readable[StopWordsRemover] { + + @Since("1.6.0") + override def read: Reader[StopWordsRemover] = new DefaultParamsReader[StopWordsRemover] + + @Since("1.6.0") + override def load(path: String): StopWordsRemover = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 486274cd75..f782a272d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -18,13 +18,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.Transformer -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -188,9 +188,8 @@ class StringIndexerModel ( * @see [[StringIndexer]] for converting strings into indices */ @Experimental -class IndexToString private[ml] ( - override val uid: String) extends Transformer - with HasInputCol with HasOutputCol { +class IndexToString private[ml] (override val uid: String) + extends Transformer with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("idxToStr")) @@ -257,4 +256,17 @@ class IndexToString private[ml] ( override def copy(extra: ParamMap): IndexToString = { defaultCopy(extra) } + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object IndexToString extends Readable[IndexToString] { + + @Since("1.6.0") + override def read: Reader[IndexToString] = new DefaultParamsReader[IndexToString] + + @Since("1.6.0") + override def load(path: String): IndexToString = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 1b82b40caa..0e4445d1e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -30,7 +30,8 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType} * @see [[RegexTokenizer]] */ @Experimental -class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[String], Tokenizer] { +class Tokenizer(override val uid: String) + extends UnaryTransformer[String, Seq[String], Tokenizer] with Writable { def this() = this(Identifiable.randomUID("tok")) @@ -45,6 +46,19 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object Tokenizer extends Readable[Tokenizer] { + + @Since("1.6.0") + override def read: Reader[Tokenizer] = new DefaultParamsReader[Tokenizer] + + @Since("1.6.0") + override def load(path: String): Tokenizer = read.load(path) } /** @@ -56,7 +70,7 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S */ @Experimental class RegexTokenizer(override val uid: String) - extends UnaryTransformer[String, Seq[String], RegexTokenizer] { + extends UnaryTransformer[String, Seq[String], RegexTokenizer] with Writable { def this() = this(Identifiable.randomUID("regexTok")) @@ -131,4 +145,17 @@ class RegexTokenizer(override val uid: String) override protected def outputDataType: DataType = new ArrayType(StringType, true) override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) +} + +@Since("1.6.0") +object RegexTokenizer extends Readable[RegexTokenizer] { + + @Since("1.6.0") + override def read: Reader[RegexTokenizer] = new DefaultParamsReader[RegexTokenizer] + + @Since("1.6.0") + override def load(path: String): RegexTokenizer = read.load(path) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 086917fa68..7e54205292 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ @@ -37,7 +37,7 @@ import org.apache.spark.sql.types._ */ @Experimental class VectorAssembler(override val uid: String) - extends Transformer with HasInputCols with HasOutputCol { + extends Transformer with HasInputCols with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("vecAssembler")) @@ -120,9 +120,19 @@ class VectorAssembler(override val uid: String) } override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } -private object VectorAssembler { +@Since("1.6.0") +object VectorAssembler extends Readable[VectorAssembler] { + + @Since("1.6.0") + override def read: Reader[VectorAssembler] = new DefaultParamsReader[VectorAssembler] + + @Since("1.6.0") + override def load(path: String): VectorAssembler = read.load(path) private[feature] def assemble(vv: Any*): Vector = { val indices = ArrayBuilder.make[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index fb3387d4aa..911582b55b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -17,12 +17,12 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{Since, Experimental} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} -import org.apache.spark.ml.util.{Identifiable, MetadataUtils, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ @@ -42,7 +42,7 @@ import org.apache.spark.sql.types.StructType */ @Experimental final class VectorSlicer(override val uid: String) - extends Transformer with HasInputCol with HasOutputCol { + extends Transformer with HasInputCol with HasOutputCol with Writable { def this() = this(Identifiable.randomUID("vectorSlicer")) @@ -151,12 +151,16 @@ final class VectorSlicer(override val uid: String) } override def copy(extra: ParamMap): VectorSlicer = defaultCopy(extra) + + @Since("1.6.0") + override def write: Writer = new DefaultParamsWriter(this) } -private[feature] object VectorSlicer { +@Since("1.6.0") +object VectorSlicer extends Readable[VectorSlicer] { /** Return true if given feature indices are valid */ - def validIndices(indices: Array[Int]): Boolean = { + private[feature] def validIndices(indices: Array[Int]): Boolean = { if (indices.isEmpty) { true } else { @@ -165,7 +169,13 @@ private[feature] object VectorSlicer { } /** Return true if given feature names are valid */ - def validNames(names: Array[String]): Boolean = { + private[feature] def validNames(names: Array[String]): Boolean = { names.forall(_.nonEmpty) && names.length == names.distinct.length } + + @Since("1.6.0") + override def read: Reader[VectorSlicer] = new DefaultParamsReader[VectorSlicer] + + @Since("1.6.0") + override def load(path: String): VectorSlicer = read.load(path) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala index 9dfa1439cc..6d2d8fe714 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala @@ -69,10 +69,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } test("read/write") { - val binarizer = new Binarizer() - .setInputCol("feature") - .setOutputCol("binarized_feature") + val t = new Binarizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") .setThreshold(0.1) - testDefaultReadWrite(binarizer) + testDefaultReadWrite(t) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 0eba34fda6..9ea7d43176 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -21,13 +21,13 @@ import scala.util.Random import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} -class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Bucketizer) @@ -112,6 +112,14 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext { val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) assert(bsResult ~== lsResult absTol 1e-5) } + + test("read/write") { + val t = new Bucketizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setSplits(Array(0.1, 0.8, 0.9)) + testDefaultReadWrite(t) + } } private object BucketizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index 37ed2367c3..0f2aafebaf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -22,6 +22,7 @@ import scala.beans.BeanInfo import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -29,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("forward transform of discrete cosine matches jTransforms result") { val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) @@ -45,6 +46,14 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { testDCT(data, inverse) } + test("read/write") { + val t = new DCT() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setInverse(true) + testDefaultReadWrite(t) + } + private def testDCT(data: Vector, inverse: Boolean): Unit = { val expectedResultBuffer = data.toArray.clone() if (inverse) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 4157b84b29..0dcd0f4946 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils -class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { +class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new HashingTF) @@ -50,4 +51,12 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) assert(features ~== expected absTol 1e-14) } + + test("read/write") { + val t = new HashingTF() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setNumFeatures(10) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 2beb62ca08..932d331b47 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite @@ -26,7 +27,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.functions.col -class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { +class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Interaction()) } @@ -162,4 +163,11 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext { new NumericAttribute(Some("a_2:b_1:c"), Some(9)))) assert(attrs === expectedAttrs) } + + test("read/write") { + val t = new Interaction() + .setInputCols(Array("myInputCol", "myInputCol2")) + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index ab97e3dbc6..58fda29aa1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -20,13 +20,14 @@ package org.apache.spark.ml.feature import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) -class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { +class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import org.apache.spark.ml.feature.NGramSuite._ test("default behavior yields bigram features") { @@ -79,6 +80,14 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext { ))) testNGram(nGram, dataset) } + + test("read/write") { + val t = new NGram() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setN(3) + testDefaultReadWrite(t) + } } object NGramSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala index 9f03470b7f..de3d438ce8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row, SQLContext} -class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var data: Array[Vector] = _ @transient var dataFrame: DataFrame = _ @@ -104,6 +105,14 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext { assertValues(result, l1Normalized) } + + test("read/write") { + val t = new Normalizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setP(3.0) + testDefaultReadWrite(t) + } } private object NormalizerSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 321eeb8439..76d12050f9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -20,12 +20,14 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions.col -class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { +class OneHotEncoderSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { def stringIndexed(): DataFrame = { val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2) @@ -101,4 +103,12 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext { assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) } + + test("read/write") { + val t = new OneHotEncoder() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setDropLast(false) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 29eebd8960..70892dc571 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -21,12 +21,14 @@ import org.apache.spark.ml.param.ParamsSuite import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row -class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext { +class PolynomialExpansionSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new PolynomialExpansion) @@ -98,5 +100,13 @@ class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext throw new TestFailedException("Unmatched data types after polynomial expansion", 0) } } + + test("read/write") { + val t = new PolynomialExpansion() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setDegree(3) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index b2bdd8935f..3a4f6d235a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.ml.feature import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.{SparkContext, SparkFunSuite} -class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class QuantileDiscretizerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import org.apache.spark.ml.feature.QuantileDiscretizerSuite._ test("Test quantile discretizer") { @@ -67,6 +70,14 @@ class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.") } } + + test("read/write") { + val t = new QuantileDiscretizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setNumBuckets(6) + testDefaultReadWrite(t) + } } private object QuantileDiscretizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index d19052881a..553e0b8702 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { +class SQLTransformerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new SQLTransformer()) @@ -41,4 +43,10 @@ class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(resultSchema == expected.schema) assert(result.collect().toSeq == expected.collect().toSeq) } + + test("read/write") { + val t = new SQLTransformer() + .setStatement("select * from __THIS__") + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index e0d433f566..fb217e0c1d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @@ -32,7 +33,9 @@ object StopWordsRemoverSuite extends SparkFunSuite { } } -class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { +class StopWordsRemoverSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import StopWordsRemoverSuite._ test("StopWordsRemover default") { @@ -77,4 +80,13 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext { testStopWordsRemover(remover, dataSet) } + + test("read/write") { + val t = new StopWordsRemover() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setStopWords(Array("the", "a")) + .setCaseSensitive(true) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index ddcdb5f421..be37bfb438 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleTy import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { +class StringIndexerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new StringIndexer) @@ -173,4 +174,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext { val outSchema = idxToStr.transformSchema(inSchema) assert(outSchema("output").dataType === StringType) } + + test("read/write") { + val t = new IndexToString() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setLabels(Array("a", "b", "c")) + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index a02992a240..36e8e5d868 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -21,20 +21,30 @@ import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) -class TokenizerSuite extends SparkFunSuite { +class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new Tokenizer) } + + test("read/write") { + val t = new Tokenizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } -class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { +class RegexTokenizerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + import org.apache.spark.ml.feature.RegexTokenizerSuite._ test("params") { @@ -81,6 +91,17 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext { )) testRegexTokenizer(tokenizer, dataset) } + + test("read/write") { + val t = new RegexTokenizer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMinTokenLength(2) + .setGaps(false) + .setPattern("hi") + .setToLowercase(false) + testDefaultReadWrite(t) + } } object RegexTokenizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index bb4d5b983e..fb21ab6b9b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -25,7 +26,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col -class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { +class VectorAssemblerSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { ParamsSuite.checkParams(new VectorAssembler) @@ -101,4 +103,11 @@ class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext { assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5)) assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6)) } + + test("read/write") { + val t = new VectorAssembler() + .setInputCols(Array("myInputCol", "myInputCol2")) + .setOutputCol("myOutputCol") + testDefaultReadWrite(t) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index a6c2fba836..74706a23e0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -20,12 +20,13 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, Row, SQLContext} -class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { +class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { test("params") { val slicer = new VectorSlicer @@ -106,4 +107,13 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext { vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4")) validateResults(vectorSlicer.transform(df)) } + + test("read/write") { + val t = new VectorSlicer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setIndices(Array(1, 3)) + .setNames(Array("a", "d")) + testDefaultReadWrite(t) + } } |