aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/java')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java6
1 files changed, 1 insertions, 5 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
index 4b13ba6f9c..7f568f4e0d 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
@@ -29,7 +29,6 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
-import org.apache.spark.sql.types.DataTypes;
// $example off$
public class JavaALSExample {
@@ -109,10 +108,7 @@ public class JavaALSExample {
ALSModel model = als.fit(training);
// Evaluate the model by computing the RMSE on the test data
- Dataset<Row> rawPredictions = model.transform(test);
- Dataset<Row> predictions = rawPredictions
- .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
- .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));
+ Dataset<Row> predictions = model.transform(test);
RegressionEvaluator evaluator = new RegressionEvaluator()
.setMetricName("rmse")