Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala           |  8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala          | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala                 | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala           | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala         | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala               | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala          | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala       | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala      | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala    | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala       | 13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala           | 25
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala     | 11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala        | 12
16 files changed, 174 insertions(+), 22 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
index 9dfa1439cc..6d2d8fe714 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BinarizerSuite.scala
@@ -69,10 +69,10 @@ class BinarizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
}
test("read/write") {
- val binarizer = new Binarizer()
- .setInputCol("feature")
- .setOutputCol("binarized_feature")
+ val t = new Binarizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
.setThreshold(0.1)
- testDefaultReadWrite(binarizer)
+ testDefaultReadWrite(t)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
index 0eba34fda6..9ea7d43176 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala
@@ -21,13 +21,13 @@ import scala.util.Random
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row}
-class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new Bucketizer)
@@ -112,6 +112,14 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext {
val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x)))
assert(bsResult ~== lsResult absTol 1e-5)
}
+
+ test("read/write") {
+ val t = new Bucketizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setSplits(Array(0.1, 0.8, 0.9))
+ testDefaultReadWrite(t)
+ }
}
private object BucketizerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
index 37ed2367c3..0f2aafebaf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala
@@ -22,6 +22,7 @@ import scala.beans.BeanInfo
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -29,7 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class DCTTestData(vec: Vector, wantedVec: Vector)
-class DCTSuite extends SparkFunSuite with MLlibTestSparkContext {
+class DCTSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("forward transform of discrete cosine matches jTransforms result") {
val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
@@ -45,6 +46,14 @@ class DCTSuite extends SparkFunSuite with MLlibTestSparkContext {
testDCT(data, inverse)
}
+ test("read/write") {
+ val t = new DCT()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setInverse(true)
+ testDefaultReadWrite(t)
+ }
+
private def testDCT(data: Vector, inverse: Boolean): Unit = {
val expectedResultBuffer = data.toArray.clone()
if (inverse) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 4157b84b29..0dcd0f4946 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
-class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
+class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new HashingTF)
@@ -50,4 +51,12 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
assert(features ~== expected absTol 1e-14)
}
+
+ test("read/write") {
+ val t = new HashingTF()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setNumFeatures(10)
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
index 2beb62ca08..932d331b47 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.feature
import scala.collection.mutable.ArrayBuilder
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.ParamsSuite
@@ -26,7 +27,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.functions.col
-class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new Interaction())
}
@@ -162,4 +163,11 @@ class InteractionSuite extends SparkFunSuite with MLlibTestSparkContext {
new NumericAttribute(Some("a_2:b_1:c"), Some(9))))
assert(attrs === expectedAttrs)
}
+
+ test("read/write") {
+ val t = new Interaction()
+ .setInputCols(Array("myInputCol", "myInputCol2"))
+ .setOutputCol("myOutputCol")
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
index ab97e3dbc6..58fda29aa1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -20,13 +20,14 @@ package org.apache.spark.ml.feature
import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
-class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
+class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
import org.apache.spark.ml.feature.NGramSuite._
test("default behavior yields bigram features") {
@@ -79,6 +80,14 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
)))
testNGram(nGram, dataset)
}
+
+ test("read/write") {
+ val t = new NGram()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setN(3)
+ testDefaultReadWrite(t)
+ }
}
object NGramSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
index 9f03470b7f..de3d438ce8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NormalizerSuite.scala
@@ -18,13 +18,14 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@transient var data: Array[Vector] = _
@transient var dataFrame: DataFrame = _
@@ -104,6 +105,14 @@ class NormalizerSuite extends SparkFunSuite with MLlibTestSparkContext {
assertValues(result, l1Normalized)
}
+
+ test("read/write") {
+ val t = new Normalizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setP(3.0)
+ testDefaultReadWrite(t)
+ }
}
private object NormalizerSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
index 321eeb8439..76d12050f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
@@ -20,12 +20,14 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
-class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
+class OneHotEncoderSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
@@ -101,4 +103,12 @@ class OneHotEncoderSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0))
assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1))
}
+
+ test("read/write") {
+ val t = new OneHotEncoder()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setDropLast(false)
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
index 29eebd8960..70892dc571 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala
@@ -21,12 +21,14 @@ import org.apache.spark.ml.param.ParamsSuite
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.sql.Row
-class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PolynomialExpansionSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new PolynomialExpansion)
@@ -98,5 +100,13 @@ class PolynomialExpansionSuite extends SparkFunSuite with MLlibTestSparkContext
throw new TestFailedException("Unmatched data types after polynomial expansion", 0)
}
}
+
+ test("read/write") {
+ val t = new PolynomialExpansion()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setDegree(3)
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index b2bdd8935f..3a4f6d235a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -18,11 +18,14 @@
package org.apache.spark.ml.feature
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkContext, SparkFunSuite}
-class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class QuantileDiscretizerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
import org.apache.spark.ml.feature.QuantileDiscretizerSuite._
test("Test quantile discretizer") {
@@ -67,6 +70,14 @@ class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext
assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.")
}
}
+
+ test("read/write") {
+ val t = new QuantileDiscretizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setNumBuckets(6)
+ testDefaultReadWrite(t)
+ }
}
private object QuantileDiscretizerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index d19052881a..553e0b8702 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -19,9 +19,11 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class SQLTransformerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new SQLTransformer())
@@ -41,4 +43,10 @@ class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(resultSchema == expected.schema)
assert(result.collect().toSeq == expected.collect().toSeq)
}
+
+ test("read/write") {
+ val t = new SQLTransformer()
+ .setStatement("select * from __THIS__")
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index e0d433f566..fb217e0c1d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@@ -32,7 +33,9 @@ object StopWordsRemoverSuite extends SparkFunSuite {
}
}
-class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
+class StopWordsRemoverSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
import StopWordsRemoverSuite._
test("StopWordsRemover default") {
@@ -77,4 +80,13 @@ class StopWordsRemoverSuite extends SparkFunSuite with MLlibTestSparkContext {
testStopWordsRemover(remover, dataSet)
}
+
+ test("read/write") {
+ val t = new StopWordsRemover()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setStopWords(Array("the", "a"))
+ .setCaseSensitive(true)
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index ddcdb5f421..be37bfb438 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -21,12 +21,13 @@ import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleTy
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
-class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class StringIndexerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new StringIndexer)
@@ -173,4 +174,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
val outSchema = idxToStr.transformSchema(inSchema)
assert(outSchema("output").dataType === StringType)
}
+
+ test("read/write") {
+ val t = new IndexToString()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setLabels(Array("a", "b", "c"))
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index a02992a240..36e8e5d868 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -21,20 +21,30 @@ import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
-class TokenizerSuite extends SparkFunSuite {
+class TokenizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new Tokenizer)
}
+
+ test("read/write") {
+ val t = new Tokenizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ testDefaultReadWrite(t)
+ }
}
-class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class RegexTokenizerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
import org.apache.spark.ml.feature.RegexTokenizerSuite._
test("params") {
@@ -81,6 +91,17 @@ class RegexTokenizerSuite extends SparkFunSuite with MLlibTestSparkContext {
))
testRegexTokenizer(tokenizer, dataset)
}
+
+ test("read/write") {
+ val t = new RegexTokenizer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setMinTokenLength(2)
+ .setGaps(false)
+ .setPattern("hi")
+ .setToLowercase(false)
+ testDefaultReadWrite(t)
+ }
}
object RegexTokenizerSuite extends SparkFunSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
index bb4d5b983e..fb21ab6b9b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ml.feature
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
@@ -25,7 +26,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
-class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class VectorAssemblerSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
ParamsSuite.checkParams(new VectorAssembler)
@@ -101,4 +103,11 @@ class VectorAssemblerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(features.getAttr(5) === NumericAttribute.defaultAttr.withIndex(5))
assert(features.getAttr(6) === NumericAttribute.defaultAttr.withIndex(6))
}
+
+ test("read/write") {
+ val t = new VectorAssembler()
+ .setInputCols(Array("myInputCol", "myInputCol2"))
+ .setOutputCol("myOutputCol")
+ testDefaultReadWrite(t)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index a6c2fba836..74706a23e0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
+class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
val slicer = new VectorSlicer
@@ -106,4 +107,13 @@ class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext {
vectorSlicer.setIndices(Array.empty).setNames(Array("f1", "f4"))
validateResults(vectorSlicer.transform(df))
}
+
+ test("read/write") {
+ val t = new VectorSlicer()
+ .setInputCol("myInputCol")
+ .setOutputCol("myOutputCol")
+ .setIndices(Array(1, 3))
+ .setNames(Array("a", "d"))
+ testDefaultReadWrite(t)
+ }
}