diff options
author | Kousuke Saruta <sarutak@oss.nttdata.co.jp> | 2015-12-29 05:33:19 +0900 |
---|---|---|
committer | Kousuke Saruta <sarutak@oss.nttdata.co.jp> | 2015-12-29 05:33:19 +0900 |
commit | 07165ca06fe0866677525f85fec25e4dbd336674 (patch) | |
tree | 5f949cda3a57bacd56ced1b3bd620e7724dece50 /mllib/src/test/scala | |
parent | e01c6c8664d74d434e9b6b3c8c70570f01d4a0a4 (diff) | |
download | spark-07165ca06fe0866677525f85fec25e4dbd336674.tar.gz spark-07165ca06fe0866677525f85fec25e4dbd336674.tar.bz2 spark-07165ca06fe0866677525f85fec25e4dbd336674.zip |
[SPARK-12424][ML] The implementation of ParamMap#filter is wrong.
ParamMap#filter uses `mutable.Map#filterKeys`. The return type of `filterKey` is collection.Map, not mutable.Map but the result is casted to mutable.Map using `asInstanceOf` so we get `ClassCastException`.
Also, the return type of Map#filterKeys is not Serializable. It's the issue of Scala (https://issues.scala-lang.org/browse/SI-6654).
Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp>
Closes #10381 from sarutak/SPARK-12424.
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index a1878be747..748868554f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.ml.param +import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream} + import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MyParams import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { @@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite { val t3 = t.copy(ParamMap(t.maxIter -> 20)) assert(t3.isSet(t3.maxIter)) } + + test("Filtering ParamMap") { + val params1 = new MyParams("my_params1") + val params2 = new MyParams("my_params2") + val paramMap = ParamMap( + params1.intParam -> 1, + params2.intParam -> 1, + params1.doubleParam -> 0.2, + params2.doubleParam -> 0.2) + val filteredParamMap = paramMap.filter(params1) + + assert(filteredParamMap.size === 2) + filteredParamMap.toSeq.foreach { + case ParamPair(p, _) => + assert(p.parent === params1.uid) + } + + // At the previous implementation of ParamMap#filter, + // mutable.Map#filterKeys was used internally but + // the return type of the method is not serializable (see SI-6654). + // Now mutable.Map#filter is used instead of filterKeys and the return type is serializable. + // So let's ensure serializability. + val objOut = new ObjectOutputStream(new ByteArrayOutputStream()) + objOut.writeObject(filteredParamMap) + } } object ParamsSuite extends SparkFunSuite { |