aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala10
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 85df6da7a1..30bd390381 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -20,17 +20,17 @@ package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model, Pipeline}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
+import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{ParamMap, ParamPair}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
-import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
-import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.{DataFrame, Dataset}
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.types.StructType
class CrossValidatorSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -69,7 +69,7 @@ class CrossValidatorSuite
test("cross validation with linear regression") {
val dataset = spark.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
- 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
+ 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2).map(_.asML))
val trainer = new LinearRegression().setSolver("l-bfgs")
val lrParamMaps = new ParamGridBuilder()