aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKousuke Saruta <sarutak@oss.nttdata.co.jp>2015-12-29 05:33:19 +0900
committerKousuke Saruta <sarutak@oss.nttdata.co.jp>2015-12-29 05:33:19 +0900
commit07165ca06fe0866677525f85fec25e4dbd336674 (patch)
tree5f949cda3a57bacd56ced1b3bd620e7724dece50
parente01c6c8664d74d434e9b6b3c8c70570f01d4a0a4 (diff)
downloadspark-07165ca06fe0866677525f85fec25e4dbd336674.tar.gz
spark-07165ca06fe0866677525f85fec25e4dbd336674.tar.bz2
spark-07165ca06fe0866677525f85fec25e4dbd336674.zip
[SPARK-12424][ML] The implementation of ParamMap#filter is wrong.
ParamMap#filter uses `mutable.Map#filterKeys`. The return type of `filterKeys` is collection.Map, not mutable.Map, but the result is cast to mutable.Map using `asInstanceOf`, so we get a `ClassCastException`. Also, the map returned by Map#filterKeys is not Serializable; this is a known Scala issue (https://issues.scala-lang.org/browse/SI-6654). Author: Kousuke Saruta <sarutak@oss.nttdata.co.jp> Closes #10381 from sarutak/SPARK-12424.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala28
2 files changed, 34 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index ee7e89edd8..c0546695e4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -859,8 +859,12 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any])
* Filters this param map for the given parent.
*/
def filter(parent: Params): ParamMap = {
- val filtered = map.filterKeys(_.parent == parent)
- new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]])
+ // Don't use filterKeys because mutable.Map#filterKeys
+ // returns an instance of collection.Map, not mutable.Map,
+ // so casting its result to mutable.Map throws ClassCastException.
+ // Not using filterKeys also avoids SI-6654.
+ val filtered = map.filter { case (k, _) => k.parent == parent.uid }
+ new ParamMap(filtered)
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index a1878be747..748868554f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.ml.param
+import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}
+
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.MyParams
import org.apache.spark.mllib.linalg.{Vector, Vectors}
class ParamsSuite extends SparkFunSuite {
@@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite {
val t3 = t.copy(ParamMap(t.maxIter -> 20))
assert(t3.isSet(t3.maxIter))
}
+
+ test("Filtering ParamMap") {
+ val params1 = new MyParams("my_params1")
+ val params2 = new MyParams("my_params2")
+ val paramMap = ParamMap(
+ params1.intParam -> 1,
+ params2.intParam -> 1,
+ params1.doubleParam -> 0.2,
+ params2.doubleParam -> 0.2)
+ val filteredParamMap = paramMap.filter(params1)
+
+ assert(filteredParamMap.size === 2)
+ filteredParamMap.toSeq.foreach {
+ case ParamPair(p, _) =>
+ assert(p.parent === params1.uid)
+ }
+
+ // In the previous implementation of ParamMap#filter,
+ // mutable.Map#filterKeys was used internally, but
+ // the map that method returns is not serializable (see SI-6654).
+ // Now mutable.Map#filter is used instead of filterKeys, and its result is serializable.
+ // So let's ensure serializability.
+ val objOut = new ObjectOutputStream(new ByteArrayOutputStream())
+ objOut.writeObject(filteredParamMap)
+ }
}
object ParamsSuite extends SparkFunSuite {