Diffstat (limited to 'python/pyspark/ml/recommendation.py')
-rw-r--r--  python/pyspark/ml/recommendation.py | 18
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 86c00d9165..bac2a3029f 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -65,7 +65,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
indicated user preferences rather than explicit ratings given to
items.
- >>> df = sqlContext.createDataFrame(
+ >>> df = spark.createDataFrame(
... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
... ["user", "item", "rating"])
>>> als = ALS(rank=10, maxIter=5)
@@ -74,7 +74,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
10
>>> model.userFactors.orderBy("id").collect()
[Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
- >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
+ >>> test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
>>> predictions[0]
Row(user=0, item=2, prediction=-0.13807615637779236)
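
The two hunks above migrate the doctest examples from the SQLContext entry point to SparkSession. The following is a minimal standalone sketch of the same ALS example, assuming a local PySpark installation; exact prediction values depend on the random seed, so none are asserted here.

    from pyspark.sql import SparkSession
    from pyspark.ml.recommendation import ALS

    spark = SparkSession.builder.master("local[2]").appName("als-sketch").getOrCreate()

    # (user, item, rating) triples from the doctest above; ALS's default
    # column names are "user", "item", and "rating", so no configuration
    # of userCol/itemCol/ratingCol is needed.
    df = spark.createDataFrame(
        [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
        ["user", "item", "rating"])
    model = ALS(rank=10, maxIter=5).fit(df)

    # Score unseen (user, item) pairs, as the doctest's `test` DataFrame does.
    test = spark.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
    model.transform(test).show()

    spark.stop()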
@@ -370,21 +370,23 @@ class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
if __name__ == "__main__":
import doctest
import pyspark.ml.recommendation
- from pyspark.context import SparkContext
- from pyspark.sql import SQLContext
+ from pyspark.sql import SparkSession
globs = pyspark.ml.recommendation.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- sc = SparkContext("local[2]", "ml.recommendation tests")
- sqlContext = SQLContext(sc)
+ spark = SparkSession.builder\
+ .master("local[2]")\
+ .appName("ml.recommendation tests")\
+ .getOrCreate()
+ sc = spark.sparkContext
globs['sc'] = sc
- globs['sqlContext'] = sqlContext
+ globs['spark'] = spark
import tempfile
temp_path = tempfile.mkdtemp()
globs['temp_path'] = temp_path
try:
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
- sc.stop()
+ spark.stop()
finally:
from shutil import rmtree
try:
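
For reference, the builder pattern introduced in the last hunk has two properties the test harness relies on: the session exposes its underlying SparkContext (so doctests that still reference `sc` keep working), and stopping the session also stops that context, which is why `sc.stop()` becomes `spark.stop()`. A short sketch under those assumptions:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder \
        .master("local[2]") \
        .appName("ml.recommendation tests") \
        .getOrCreate()

    # The session wraps a SparkContext; the harness stores this in globs['sc'].
    sc = spark.sparkContext
    print(sc.appName)  # -> ml.recommendation tests

    # stop() shuts down the session together with its underlying SparkContext.
    spark.stop()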