Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala                    | 4 ++--
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java                          | 2 +-
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 2 +-
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java       | 2 +-
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java             | 2 +-
5 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 5d51c51346..324b1ba784 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -76,8 +76,8 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
val metrics = new Array[Double](epm.size)
val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
- val trainingDataset = sqlCtx.applySchema(training, schema).cache()
- val validationDataset = sqlCtx.applySchema(validation, schema).cache()
+ val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
+ val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
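
A minimal usage sketch (not part of this patch; assumes an existing SparkContext `sc` and illustrative column names): createDataFrame takes the same (RDD[Row], StructType) arguments that applySchema did, so call sites like the one above only swap the method name.

import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

val sqlCtx = new SQLContext(sc)
// Explicit schema for a single "label" column, mirroring the CrossValidator call site.
val schema = StructType(Seq(StructField("label", DoubleType, nullable = false)))
val rows = sc.parallelize(Seq(Row(0.0), Row(1.0)))
val df = sqlCtx.createDataFrame(rows, schema)  // formerly sqlCtx.applySchema(rows, schema)
df.cache()
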
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 50995ffef9..0a8c9e5954 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -45,7 +45,7 @@ public class JavaPipelineSuite {
jsql = new SQLContext(jsc);
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
- dataset = jsql.applySchema(points, LabeledPoint.class);
+ dataset = jsql.createDataFrame(points, LabeledPoint.class);
}
@After
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index d4b6644792..3f8e59de0f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -50,7 +50,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 40d5a92bb3..0cc36c8d56 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -46,7 +46,7 @@ public class JavaLinearRegressionSuite implements Serializable {
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 074b58c07d..0bb6b489f2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -45,7 +45,7 @@ public class JavaCrossValidatorSuite implements Serializable {
jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
- dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
}
@After
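
The Java suites above use the bean-class overload rather than an explicit schema. A rough Scala equivalent (a sketch only, not from this patch; assumes an existing SparkContext `sc`) uses the case-class reflection variant of createDataFrame:

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext

val sqlCtx = new SQLContext(sc)
// Schema is inferred by reflection from the LabeledPoint case class,
// mirroring jsql.createDataFrame(datasetRDD, LabeledPoint.class) in the tests.
val points = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(0.5)),
  LabeledPoint(0.0, Vectors.dense(-0.5))))
val dataset = sqlCtx.createDataFrame(points)
dataset.registerTempTable("dataset")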