path: root/mllib/src/test
author     Xiangrui Meng <meng@databricks.com>  2016-04-11 09:28:28 -0700
committer  Xiangrui Meng <meng@databricks.com>  2016-04-11 09:28:28 -0700
commit     1c751fcf488189e5176546fe0d00f560ffcf1cec (patch)
tree       40863cdb5ac52b6fdc74a22d64853ea07826e6be /mllib/src/test
parent     e82d95bf63f57cefa02dc545ceb451ecdeedce28 (diff)
[SPARK-14500] [ML] Accept Dataset[_] instead of DataFrame in MLlib APIs
## What changes were proposed in this pull request?

This PR updates MLlib APIs to accept `Dataset[_]` as input where `DataFrame` was the input type. This PR doesn't change the output type. In Java, `Dataset[_]` maps to `Dataset<?>`, which includes `Dataset<Row>`. Some implementations were changed in order to return `DataFrame`. Tests and examples were updated. Note that this is a breaking change for subclasses of Transformer/Estimator. Lol, we don't have to rename the input argument, which has been `dataset` since Spark 1.2.

TODOs:

- [x] update MiMaExcludes (seems all covered by explicit filters from SPARK-13920)
- [x] Python
- [x] add a new test to accept Dataset[LabeledPoint]
- [x] remove unused imports of Dataset

## How was this patch tested?

Existing unit tests with some modifications.

cc: rxin jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #12274 from mengxr/SPARK-14500.
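To make the breaking change concrete before the diff: a minimal sketch of a custom Transformer under the new signature. The class name `IdentityStage` is hypothetical; the before/after signatures mirror the `WritableStage`/`UnWritableStage` changes in PipelineSuite.scala below.

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// Hypothetical no-op stage. Before this commit a subclass overrode
//   def transform(dataset: DataFrame): DataFrame
// After it, the input widens to Dataset[_] while the output type is
// unchanged, so implementations convert back with .toDF on the way out.
class IdentityStage(override val uid: String) extends Transformer {
  def this() = this(Identifiable.randomUID("identity"))

  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF

  override def transformSchema(schema: StructType): StructType = schema

  override def copy(extra: ParamMap): IdentityStage = defaultCopy(extra)
}
```

Since any typed `Dataset[T]` conforms to `Dataset[_]` (and `DataFrame` is `Dataset[Row]`), typed data can now be passed straight into estimators, e.g. `new GeneralizedLinearRegression().fit(df.as[LabeledPoint])`, which is what the new test in GeneralizedLinearRegressionSuite.scala exercises.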
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala                                      | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala             |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala                     |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala                      |  6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala                    |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala                    |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala                             |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala                                |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala                                 |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala                      |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala                         |  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala                             |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala        |  8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala                         |  8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala                   |  6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala                          |  4

17 files changed, 50 insertions(+), 36 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index f3321fb5a1..a8c4ac6d05 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -51,6 +51,12 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val dataset3 = mock[DataFrame]
val dataset4 = mock[DataFrame]
+ when(dataset0.toDF).thenReturn(dataset0)
+ when(dataset1.toDF).thenReturn(dataset1)
+ when(dataset2.toDF).thenReturn(dataset2)
+ when(dataset3.toDF).thenReturn(dataset3)
+ when(dataset4.toDF).thenReturn(dataset4)
+
when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
when(model0.copy(any[ParamMap])).thenReturn(model0)
when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
@@ -213,7 +219,7 @@ class WritableStage(override val uid: String) extends Transformer with MLWritabl
override def write: MLWriter = new DefaultParamsWriter(this)
- override def transform(dataset: DataFrame): DataFrame = dataset
+ override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF
override def transformSchema(schema: StructType): StructType = schema
}
@@ -234,7 +240,7 @@ class UnWritableStage(override val uid: String) extends Transformer {
override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)
- override def transform(dataset: DataFrame): DataFrame = dataset
+ override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF
override def transformSchema(schema: StructType): StructType = schema
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 7eefaf2346..48db428130 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -29,13 +29,13 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.lit
class LogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
@transient var binaryDataset: DataFrame = _
private val eps: Double = 1e-5
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 06ff049b48..80547fad6a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -26,12 +26,12 @@ import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
class MultilayerPerceptronClassifierSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 4727cd436f..80a46fc70c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -27,11 +27,11 @@ import org.apache.spark.mllib.classification.NaiveBayesSuite._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 4131396726..f3e8fd11b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -30,12 +30,12 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.Metadata
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
@transient var rdd: RDD[LabeledPoint] = _
override def beforeAll(): Unit = {
@@ -246,7 +246,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid
setMaxIter(1)
- override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
+ override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
val labelSchema = dataset.schema($(labelCol))
// check for label attribute propagation.
assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 18f2c994b4..e641d79c17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
class BisectingKMeansSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
final val k = 5
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 8edd44e5f1..1a274aea29 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -20,14 +20,14 @@ package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
with DefaultReadWriteTest {
final val k = 5
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c684bc11cc..2076c745e2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -22,14 +22,14 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
private[clustering] case class TestRow(features: Vector)
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
final val k = 5
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index a1c93891c7..ee8eae8f69 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
object LDASuite {
@@ -64,7 +64,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
val k: Int = 5
val vocabSize: Int = 30
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
index 58fda29aa1..e4e15f4331 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -22,7 +22,7 @@ import scala.beans.BeanInfo
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
@BeanInfo
case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
@@ -92,7 +92,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
object NGramSuite extends SparkFunSuite {
- def testNGram(t: NGram, dataset: DataFrame): Unit = {
+ def testNGram(t: NGram, dataset: Dataset[_]): Unit = {
t.transform(dataset)
.select("nGrams", "wantedNGrams")
.collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index a5b24c1856..3505befdf8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -20,10 +20,10 @@ package org.apache.spark.ml.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
object StopWordsRemoverSuite extends SparkFunSuite {
- def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
+ def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
t.transform(dataset)
.select("filtered", "expected")
.collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 2c3255ef33..d0f3cdc841 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -115,7 +115,7 @@ class StringIndexerSuite
.setInputCol("label")
.setOutputCol("labelIndex")
val df = sqlContext.range(0L, 10L).toDF()
- assert(indexerModel.transform(df).eq(df))
+ assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
}
test("StringIndexerModel can't overwrite output column") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index 36e8e5d868..299f6223b2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
@BeanInfo
case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
@@ -106,7 +106,7 @@ class RegexTokenizerSuite
object RegexTokenizerSuite extends SparkFunSuite {
- def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
+ def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = {
t.transform(dataset)
.select("tokens", "wantedTokens")
.collect()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 2265464b51..4905f3e068 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -992,6 +992,14 @@ class GeneralizedLinearRegressionSuite
assert(expected.coefficients === actual.coefficients)
}
}
+
+ test("glm accepts Dataset[LabeledPoint]") {
+ val context = sqlContext
+ import context.implicits._
+ new GeneralizedLinearRegression()
+ .setFamily("gaussian")
+ .fit(datasetGaussianIdentity.as[LabeledPoint])
+ }
}
object GeneralizedLinearRegressionSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 7af3c6d6ed..3e734aabc5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -29,13 +29,13 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{StructField, StructType}
class CrossValidatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -311,7 +311,7 @@ object CrossValidatorSuite extends SparkFunSuite {
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
- override def fit(dataset: DataFrame): MyModel = {
+ override def fit(dataset: Dataset[_]): MyModel = {
throw new UnsupportedOperationException
}
@@ -325,7 +325,7 @@ object CrossValidatorSuite extends SparkFunSuite {
class MyEvaluator extends Evaluator {
- override def evaluate(dataset: DataFrame): Double = {
+ override def evaluate(dataset: Dataset[_]): Double = {
throw new UnsupportedOperationException
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 4030956fab..dbee47c847 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
class TrainValidationSplitSuite
@@ -158,7 +158,7 @@ object TrainValidationSplitSuite {
class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
- override def fit(dataset: DataFrame): MyModel = {
+ override def fit(dataset: Dataset[_]): MyModel = {
throw new UnsupportedOperationException
}
@@ -172,7 +172,7 @@ object TrainValidationSplitSuite {
class MyEvaluator extends Evaluator {
- override def evaluate(dataset: DataFrame): Double = {
+ override def evaluate(dataset: Dataset[_]): Double = {
throw new UnsupportedOperationException
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 16280473c6..7ebd7eb144 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
@@ -98,7 +98,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
def testEstimatorAndModelReadWrite[
E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
estimator: E,
- dataset: DataFrame,
+ dataset: Dataset[_],
testParams: Map[String, Any],
checkModelData: (M, M) => Unit): Unit = {
// Set some Params to make sure set Params are serialized.