about summary refs log tree commit diff
path: root/examples/src/main/python/ml/train_validation_split.py
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/python/ml/train_validation_split.py')
-rw-r--r-- examples/src/main/python/ml/train_validation_split.py | 10
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/examples/src/main/python/ml/train_validation_split.py b/examples/src/main/python/ml/train_validation_split.py
index 161a200c61..2e43a0f8ae 100644
--- a/examples/src/main/python/ml/train_validation_split.py
+++ b/examples/src/main/python/ml/train_validation_split.py
@@ -15,13 +15,12 @@
# limitations under the License.
#
-from pyspark import SparkContext
# $example on$
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
-from pyspark.sql import SQLContext
# $example off$
+from pyspark.sql import SparkSession
"""
This example demonstrates applying TrainValidationSplit to split data
@@ -32,11 +31,10 @@ Run with:
"""
if __name__ == "__main__":
- sc = SparkContext(appName="TrainValidationSplit")
- sqlContext = SQLContext(sc)
+ spark = SparkSession.builder.appName("TrainValidationSplit").getOrCreate()
# $example on$
# Prepare training and test data.
- data = sqlContext.read.format("libsvm")\
+ data = spark.read.format("libsvm")\
.load("data/mllib/sample_linear_regression_data.txt")
train, test = data.randomSplit([0.7, 0.3])
lr = LinearRegression(maxIter=10, regParam=0.1)
@@ -65,4 +63,4 @@ if __name__ == "__main__":
for row in prediction.take(5):
print(row)
# $example off$
- sc.stop()
+ spark.stop()