author	Xiangrui Meng <meng@databricks.com>	2015-11-18 15:47:49 -0800
committer	Xiangrui Meng <meng@databricks.com>	2015-11-18 15:47:49 -0800
commit	7e987de1770f4ab3d54bc05db8de0a1ef035941d (patch)
tree	856cbb3cf219827d4022b40675e3b79300ed91e1 /mllib/src/test
parent	5df08949f5d9e5b4b0e9c2db50c1b4eb93383de3 (diff)
[SPARK-6787][ML] add read/write to estimators under ml.feature (1)

Add read/write support to the following estimators under spark.ml:

* CountVectorizer
* IDF
* MinMaxScaler
* StandardScaler (a little awkward because we store some params in the spark.mllib model)
* StringIndexer

Added some necessary methods for read/write. Maybe we should add `private[ml] trait DefaultParamsReadable` and `DefaultParamsWritable` to save some boilerplate code, though we still need to override `load` for Java compatibility.

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9798 from mengxr/SPARK-6787.
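As a rough illustration of the boilerplate-saving traits suggested in the message above, here is a minimal sketch. The trait names come from the commit message, but the bodies are assumptions built on the existing `private[ml]` `DefaultParamsWriter`/`DefaultParamsReader` helpers; they are not part of this patch.

```scala
package org.apache.spark.ml.util

import org.apache.spark.ml.param.Params

// Hypothetical sketch only: these traits are not added by this patch.
// They assume the existing private[ml] DefaultParamsWriter/DefaultParamsReader.
private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
  // Reuse the generic params-to-JSON writer so each estimator/model
  // does not have to implement `write` by hand.
  override def write: MLWriter = new DefaultParamsWriter(this)
}

private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
  // Generic reader; concrete companion objects would still override
  // `load(path)` to return the specific type for Java compatibility,
  // as noted in the commit message.
  override def read: MLReader[T] = new DefaultParamsReader[T]
}
```

With such traits, an estimator like `CountVectorizer` could simply mix in `DefaultParamsWritable`, and its companion object could extend `DefaultParamsReadable[CountVectorizer]`, leaving only the Java-friendly `load` override to write by hand.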
Diffstat (limited to 'mllib/src/test')
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala	| 24
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala	| 19
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala	| 25
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala	| 64
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala	| 19
5 files changed, 129 insertions(+), 22 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index e192fa4850..9c99990173 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -18,14 +18,17 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
test("params") {
+ ParamsSuite.checkParams(new CountVectorizer)
ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
}
@@ -164,4 +167,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(features ~== expected absTol 1e-14)
}
}
+
+ test("CountVectorizer read/write") {
+ val t = new CountVectorizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMinDF(0.5)
+ .setMinTF(3.0)
+ .setVocabSize(10)
+ testDefaultReadWrite(t)
+ }
+
+ test("CountVectorizerModel read/write") {
+ val instance = new CountVectorizerModel("myCountVectorizerModel", Array("a", "b", "c"))
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMinTF(3.0)
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.vocabulary === instance.vocabulary)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
index 08f80af034..bc958c1585 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/IDFSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.feature.{IDFModel => OldIDFModel}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
+class IDFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
def scaleDataWithIDF(dataSet: Array[Vector], model: Vector): Array[Vector] = {
dataSet.map {
@@ -98,4 +99,20 @@ class IDFSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.")
}
}
+
+ test("IDF read/write") {
+ val t = new IDF()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMinDocFreq(5)
+ testDefaultReadWrite(t)
+ }
+
+ test("IDFModel read/write") {
+ val instance = new IDFModel("myIDFModel", new OldIDFModel(Vectors.dense(1.0, 2.0)))
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.idf === instance.idf)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index c04dda41ee..09183fe65b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}
-class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("MinMaxScaler fit basic case") {
val sqlContext = new SQLContext(sc)
@@ -69,4 +69,25 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("MinMaxScaler read/write") {
+ val t = new MinMaxScaler()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMax(1.0)
+ .setMin(-1.0)
+ testDefaultReadWrite(t)
+ }
+
+ test("MinMaxScalerModel read/write") {
+ val instance = new MinMaxScalerModel(
+ "myMinMaxScalerModel", Vectors.dense(-1.0, 0.0), Vectors.dense(1.0, 10.0))
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMin(-1.0)
+ .setMax(1.0)
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.originalMin === instance.originalMin)
+ assert(newInstance.originalMax === instance.originalMax)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
index 879a3ae875..49a4b2efe0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala
@@ -19,12 +19,16 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
-class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{
+class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext
+ with DefaultReadWriteTest {
@transient var data: Array[Vector] = _
@transient var resWithStd: Array[Vector] = _
@@ -56,23 +60,29 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{
)
}
- def assertResult(dataframe: DataFrame): Unit = {
- dataframe.select("standarded_features", "expected").collect().foreach {
+ def assertResult(df: DataFrame): Unit = {
+ df.select("standardized_features", "expected").collect().foreach {
case Row(vector1: Vector, vector2: Vector) =>
assert(vector1 ~== vector2 absTol 1E-5,
"The vector value is not correct after standardization.")
}
}
+ test("params") {
+ ParamsSuite.checkParams(new StandardScaler)
+ val oldModel = new feature.StandardScalerModel(Vectors.dense(1.0), Vectors.dense(2.0))
+ ParamsSuite.checkParams(new StandardScalerModel("empty", oldModel))
+ }
+
test("Standardization with default parameter") {
val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")
- val standardscaler0 = new StandardScaler()
+ val standardScaler0 = new StandardScaler()
.setInputCol("features")
- .setOutputCol("standarded_features")
+ .setOutputCol("standardized_features")
.fit(df0)
- assertResult(standardscaler0.transform(df0))
+ assertResult(standardScaler0.transform(df0))
}
test("Standardization with setter") {
@@ -80,29 +90,49 @@ class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext{
val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")
- val standardscaler1 = new StandardScaler()
+ val standardScaler1 = new StandardScaler()
.setInputCol("features")
- .setOutputCol("standarded_features")
+ .setOutputCol("standardized_features")
.setWithMean(true)
.setWithStd(true)
.fit(df1)
- val standardscaler2 = new StandardScaler()
+ val standardScaler2 = new StandardScaler()
.setInputCol("features")
- .setOutputCol("standarded_features")
+ .setOutputCol("standardized_features")
.setWithMean(true)
.setWithStd(false)
.fit(df2)
- val standardscaler3 = new StandardScaler()
+ val standardScaler3 = new StandardScaler()
.setInputCol("features")
- .setOutputCol("standarded_features")
+ .setOutputCol("standardized_features")
.setWithMean(false)
.setWithStd(false)
.fit(df3)
- assertResult(standardscaler1.transform(df1))
- assertResult(standardscaler2.transform(df2))
- assertResult(standardscaler3.transform(df3))
+ assertResult(standardScaler1.transform(df1))
+ assertResult(standardScaler2.transform(df2))
+ assertResult(standardScaler3.transform(df3))
+ }
+
+ test("StandardScaler read/write") {
+ val t = new StandardScaler()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setWithStd(false)
+ .setWithMean(true)
+ testDefaultReadWrite(t)
+ }
+
+ test("StandardScalerModel read/write") {
+ val oldModel = new feature.StandardScalerModel(
+ Vectors.dense(1.0, 2.0), Vectors.dense(3.0, 4.0), false, true)
+ val instance = new StandardScalerModel("myStandardScalerModel", oldModel)
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.std === instance.std)
+ assert(newInstance.mean === instance.mean)
+ assert(newInstance.getWithStd === instance.getWithStd)
+ assert(newInstance.getWithMean === instance.getWithMean)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index be37bfb438..749bfac747 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -118,6 +118,23 @@ class StringIndexerSuite
assert(indexerModel.transform(df).eq(df))
}
+ test("StringIndexer read/write") {
+ val t = new StringIndexer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setHandleInvalid("skip")
+ testDefaultReadWrite(t)
+ }
+
+ test("StringIndexerModel read/write") {
+ val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c"))
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setHandleInvalid("skip")
+ val newInstance = testDefaultReadWrite(instance)
+ assert(newInstance.labels === instance.labels)
+ }
+
test("IndexToString params") {
val idxToStr = new IndexToString()
ParamsSuite.checkParams(idxToStr)
@@ -175,7 +192,7 @@ class StringIndexerSuite
assert(outSchema("output").dataType === StringType)
}
- test("read/write") {
+ test("IndexToString read/write") {
val t = new IndexToString()
.setInputCol("myInputCol")
.setOutputCol("myOutputCol")