aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorNick Pentreath <nickp@za.ibm.com>2016-05-07 10:57:40 +0200
committerNick Pentreath <nickp@za.ibm.com>2016-05-07 10:57:40 +0200
commitb0cafdb6ccff9add89dc31c45adf87c8fa906aac (patch)
tree1ce9876b0c6237387283cee2ff021dfb6815e0c4 /examples
parentdf89f1d43d4eaa1dd8a439a8e48bca16b67d5b48 (diff)
downloadspark-b0cafdb6ccff9add89dc31c45adf87c8fa906aac.tar.gz
spark-b0cafdb6ccff9add89dc31c45adf87c8fa906aac.tar.bz2
spark-b0cafdb6ccff9add89dc31c45adf87c8fa906aac.zip
[MINOR][ML][PYSPARK] ALS example cleanup
Cleans up ALS examples by removing unnecessary casts to double for `rating` and `prediction` columns, since `RegressionEvaluator` now supports `Double` & `Float` input types. ## How was this patch tested? Manual compile and run with `run-example ml.ALSExample` and `spark-submit examples/src/main/python/ml/als_example.py`. Author: Nick Pentreath <nickp@za.ibm.com> Closes #12892 from MLnick/als-examples-cleanup.
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java6
-rw-r--r--examples/src/main/python/ml/als_example.py9
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala6
3 files changed, 4 insertions, 17 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
index 4b13ba6f9c..7f568f4e0d 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java
@@ -29,7 +29,6 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
-import org.apache.spark.sql.types.DataTypes;
// $example off$
public class JavaALSExample {
@@ -109,10 +108,7 @@ public class JavaALSExample {
ALSModel model = als.fit(training);
// Evaluate the model by computing the RMSE on the test data
- Dataset<Row> rawPredictions = model.transform(test);
- Dataset<Row> predictions = rawPredictions
- .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
- .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));
+ Dataset<Row> predictions = model.transform(test);
RegressionEvaluator evaluator = new RegressionEvaluator()
.setMetricName("rmse")
diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py
index ff0829b0dd..1a979ff5b5 100644
--- a/examples/src/main/python/ml/als_example.py
+++ b/examples/src/main/python/ml/als_example.py
@@ -48,12 +48,9 @@ if __name__ == "__main__":
model = als.fit(training)
# Evaluate the model by computing the RMSE on the test data
- rawPredictions = model.transform(test)
- predictions = rawPredictions\
- .withColumn("rating", rawPredictions.rating.cast("double"))\
- .withColumn("prediction", rawPredictions.prediction.cast("double"))
- evaluator =\
- RegressionEvaluator(metricName="rmse", labelCol="rating", predictionCol="prediction")
+ predictions = model.transform(test)
+ evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",
+ predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error = " + str(rmse))
# $example off$
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala
index 7c1cfe2937..6b151a622e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala
@@ -23,10 +23,6 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
// $example off$
import org.apache.spark.sql.SparkSession
-// $example on$
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.DoubleType
-// $example off$
object ALSExample {
@@ -65,8 +61,6 @@ object ALSExample {
// Evaluate the model by computing the RMSE on the test data
val predictions = model.transform(test)
- .withColumn("rating", col("rating").cast(DoubleType))
- .withColumn("prediction", col("prediction").cast(DoubleType))
val evaluator = new RegressionEvaluator()
.setMetricName("rmse")