about | summary | refs | log | tree | commit | diff
path: root/examples/src/main/python/ml/decision_tree_regression_example.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/python/ml/decision_tree_regression_example.py')
-rw-r--r-- examples/src/main/python/ml/decision_tree_regression_example.py | 9
1 files changed, 5 insertions, 4 deletions
diff --git a/examples/src/main/python/ml/decision_tree_regression_example.py b/examples/src/main/python/ml/decision_tree_regression_example.py
index 8e20d5d857..9e8cb382a9 100644
--- a/examples/src/main/python/ml/decision_tree_regression_example.py
+++ b/examples/src/main/python/ml/decision_tree_regression_example.py
@@ -20,21 +20,20 @@ Decision Tree Regression Example.
"""
from __future__ import print_function
-from pyspark import SparkContext, SQLContext
# $example on$
from pyspark.ml import Pipeline
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator
# $example off$
+from pyspark.sql import SparkSession
if __name__ == "__main__":
- sc = SparkContext(appName="decision_tree_classification_example")
- sqlContext = SQLContext(sc)
+ spark = SparkSession.builder.appName("decision_tree_classification_example").getOrCreate()
# $example on$
# Load the data stored in LIBSVM format as a DataFrame.
- data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
+ data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
# Automatically identify categorical features, and index them.
# We specify maxCategories so features with > 4 distinct values are treated as continuous.
@@ -69,3 +68,5 @@ if __name__ == "__main__":
# summary only
print(treeModel)
# $example off$
+
+ spark.stop()