about summary refs log tree commit diff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-11-18 18:34:01 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-11-18 18:34:01 -0800
commite99d3392068bc929c900a4cc7b50e9e2b437a23a (patch)
treeecc75c1b1d75173742d95c4156cce653c12418b2 /mllib/src/test
parent59a501359a267fbdb7689058693aa788703e54b1 (diff)
downloadspark-e99d3392068bc929c900a4cc7b50e9e2b437a23a.tar.gz
spark-e99d3392068bc929c900a4cc7b50e9e2b437a23a.tar.bz2
spark-e99d3392068bc929c900a4cc7b50e9e2b437a23a.zip
[SPARK-11839][ML] refactor save/write traits
* add "ML" prefix to reader/writer/readable/writable to avoid name collision with java.util.*
* define `DefaultParamsReadable/Writable` and use them to save some code
* use `super.load` instead so people can jump directly to the doc of `Readable.load`, which documents the Java compatibility issues

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #9827 from mengxr/SPARK-11839.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala	| 14
-rw-r--r--	mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala	| 17
2 files changed, 16 insertions(+), 15 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 7f5c3895ac..12aba6bc6d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -179,8 +179,8 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
-/** Used to test [[Pipeline]] with [[Writable]] stages */
-class WritableStage(override val uid: String) extends Transformer with Writable {
+/** Used to test [[Pipeline]] with [[MLWritable]] stages */
+class WritableStage(override val uid: String) extends Transformer with MLWritable {
final val intParam: IntParam = new IntParam(this, "intParam", "doc")
@@ -192,21 +192,21 @@ class WritableStage(override val uid: String) extends Transformer with Writable
override def copy(extra: ParamMap): WritableStage = defaultCopy(extra)
- override def write: Writer = new DefaultParamsWriter(this)
+ override def write: MLWriter = new DefaultParamsWriter(this)
override def transform(dataset: DataFrame): DataFrame = dataset
override def transformSchema(schema: StructType): StructType = schema
}
-object WritableStage extends Readable[WritableStage] {
+object WritableStage extends MLReadable[WritableStage] {
- override def read: Reader[WritableStage] = new DefaultParamsReader[WritableStage]
+ override def read: MLReader[WritableStage] = new DefaultParamsReader[WritableStage]
- override def load(path: String): WritableStage = read.load(path)
+ override def load(path: String): WritableStage = super.load(path)
}
-/** Used to test [[Pipeline]] with non-[[Writable]] stages */
+/** Used to test [[Pipeline]] with non-[[MLWritable]] stages */
class UnWritableStage(override val uid: String) extends Transformer {
final val intParam: IntParam = new IntParam(this, "intParam", "doc")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index dd1e8acce9..84d06b43d6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -38,7 +38,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* @tparam T ML instance type
* @return Instance loaded from file
*/
- def testDefaultReadWrite[T <: Params with Writable](
+ def testDefaultReadWrite[T <: Params with MLWritable](
instance: T,
testParams: Boolean = true): T = {
val uid = instance.uid
@@ -52,7 +52,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
instance.save(path)
}
instance.write.overwrite().save(path)
- val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[Reader[T]]
+ val loader = instance.getClass.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
val newInstance = loader.load(path)
assert(newInstance.uid === instance.uid)
@@ -92,7 +92,8 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
* @tparam E Type of [[Estimator]]
* @tparam M Type of [[Model]] produced by estimator
*/
- def testEstimatorAndModelReadWrite[E <: Estimator[M] with Writable, M <: Model[M] with Writable](
+ def testEstimatorAndModelReadWrite[
+ E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
estimator: E,
dataset: DataFrame,
testParams: Map[String, Any],
@@ -119,7 +120,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
}
}
-class MyParams(override val uid: String) extends Params with Writable {
+class MyParams(override val uid: String) extends Params with MLWritable {
final val intParamWithDefault: IntParam = new IntParam(this, "intParamWithDefault", "doc")
final val intParam: IntParam = new IntParam(this, "intParam", "doc")
@@ -145,14 +146,14 @@ class MyParams(override val uid: String) extends Params with Writable {
override def copy(extra: ParamMap): Params = defaultCopy(extra)
- override def write: Writer = new DefaultParamsWriter(this)
+ override def write: MLWriter = new DefaultParamsWriter(this)
}
-object MyParams extends Readable[MyParams] {
+object MyParams extends MLReadable[MyParams] {
- override def read: Reader[MyParams] = new DefaultParamsReader[MyParams]
+ override def read: MLReader[MyParams] = new DefaultParamsReader[MyParams]
- override def load(path: String): MyParams = read.load(path)
+ override def load(path: String): MyParams = super.load(path)
}
class DefaultReadWriteSuite extends SparkFunSuite with MLlibTestSparkContext