diff options
author | Kousuke Saruta <sarutak@oss.nttdata.co.jp> | 2015-12-29 05:33:19 +0900 |
---|---|---|
committer | Kousuke Saruta <sarutak@oss.nttdata.co.jp> | 2015-12-29 05:33:19 +0900 |
commit | 07165ca06fe0866677525f85fec25e4dbd336674 (patch) | |
tree | 5f949cda3a57bacd56ced1b3bd620e7724dece50 | |
parent | e01c6c8664d74d434e9b6b3c8c70570f01d4a0a4 (diff) | |
download | spark-07165ca06fe0866677525f85fec25e4dbd336674.tar.gz spark-07165ca06fe0866677525f85fec25e4dbd336674.tar.bz2 spark-07165ca06fe0866677525f85fec25e4dbd336674.zip |
[SPARK-12424][ML] The implementation of ParamMap#filter is wrong.
ParamMap#filter uses `mutable.Map#filterKeys`. The return type of `filterKeys` is `collection.Map`, not `mutable.Map`, but the result is cast to `mutable.Map` using `asInstanceOf`, so we get a `ClassCastException`.
Also, the object returned by `Map#filterKeys` is not Serializable. This is a known Scala issue (https://issues.scala-lang.org/browse/SI-6654).
Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp>
Closes #10381 from sarutak/SPARK-12424.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 8 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 28 |
2 files changed, 34 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ee7e89edd8..c0546695e4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -859,8 +859,12 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Filters this param map for the given parent. */ def filter(parent: Params): ParamMap = { - val filtered = map.filterKeys(_.parent == parent) - new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]]) + // Don't use filterKeys because mutable.Map#filterKeys + // returns the instance of collections.Map, not mutable.Map. + // Otherwise, we get ClassCastException. + // Not using filterKeys also avoid SI-6654 + val filtered = map.filter { case (k, _) => k.parent == parent.uid } + new ParamMap(filtered) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index a1878be747..748868554f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.ml.param +import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream} + import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MyParams import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { @@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite { val t3 = t.copy(ParamMap(t.maxIter -> 20)) assert(t3.isSet(t3.maxIter)) } + + test("Filtering ParamMap") { + val params1 = new MyParams("my_params1") + val params2 = new MyParams("my_params2") + val paramMap = ParamMap( + params1.intParam -> 1, + params2.intParam -> 1, + params1.doubleParam -> 0.2, + params2.doubleParam -> 0.2) + val filteredParamMap = 
paramMap.filter(params1) + + assert(filteredParamMap.size === 2) + filteredParamMap.toSeq.foreach { + case ParamPair(p, _) => + assert(p.parent === params1.uid) + } + + // At the previous implementation of ParamMap#filter, + // mutable.Map#filterKeys was used internally but + // the return type of the method is not serializable (see SI-6654). + // Now mutable.Map#filter is used instead of filterKeys and the return type is serializable. + // So let's ensure serializability. + val objOut = new ObjectOutputStream(new ByteArrayOutputStream()) + objOut.writeObject(filteredParamMap) + } } object ParamsSuite extends SparkFunSuite { |