about | summary | refs | log | tree | commit | diff
path: root/examples/src/main/python/ml/als_example.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/python/ml/als_example.py')
-rw-r--r-- examples/src/main/python/ml/als_example.py 14
1 files changed, 6 insertions, 8 deletions
diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py
index 0c9ac583b2..e36444f185 100644
--- a/examples/src/main/python/ml/als_example.py
+++ b/examples/src/main/python/ml/als_example.py
@@ -21,8 +21,7 @@ import sys
if sys.version >= '3':
long = int
-from pyspark import SparkContext
-from pyspark.sql import SQLContext
+from pyspark.sql import SparkSession
# $example on$
from pyspark.ml.evaluation import RegressionEvaluator
@@ -31,15 +30,14 @@ from pyspark.sql import Row
# $example off$
if __name__ == "__main__":
- sc = SparkContext(appName="ALSExample")
- sqlContext = SQLContext(sc)
+ spark = SparkSession.builder.appName("ALSExample").getOrCreate()
# $example on$
- lines = sc.textFile("data/mllib/als/sample_movielens_ratings.txt")
- parts = lines.map(lambda l: l.split("::"))
+ lines = spark.read.text("data/mllib/als/sample_movielens_ratings.txt").rdd
+ parts = lines.map(lambda row: row.value.split("::"))
ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]),
rating=float(p[2]), timestamp=long(p[3])))
- ratings = sqlContext.createDataFrame(ratingsRDD)
+ ratings = spark.createDataFrame(ratingsRDD)
(training, test) = ratings.randomSplit([0.8, 0.2])
# Build the recommendation model using ALS on the training data
@@ -56,4 +54,4 @@ if __name__ == "__main__":
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error = " + str(rmse))
# $example off$
- sc.stop()
+ spark.stop()