aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java7
1 files changed, 4 insertions, 3 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index cc5a4ef4c2..a3fcdb54ee 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -29,14 +29,15 @@ import static org.junit.Assert.assertTrue;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaKMeansSuite implements Serializable {
private transient int k = 5;
private transient JavaSparkContext sc;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient SQLContext sql;
@Before
@@ -61,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
Vector[] centers = model.clusterCenters();
assertEquals(k, centers.length);
- DataFrame transformed = model.transform(dataset);
+ Dataset<Row> transformed = model.transform(dataset);
List<String> columns = Arrays.asList(transformed.columns());
List<String> expectedColumns = Arrays.asList("features", "prediction");
for (String column: expectedColumns) {