aboutsummaryrefslogtreecommitdiff
path: root/examples/src
diff options
context:
space:
mode:
authorJeremyNixon <jnixon2@gmail.com>2016-03-10 09:09:56 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2016-03-10 09:18:15 +0200
commit3e3c3d58d8d42b42e930d42eb70b0e84d02967eb (patch)
tree9383504010bacd44331a936b1c349db7cb90af15 /examples/src
parent8bcad28a5a6788c96bf1c302eb6f18d37b798b03 (diff)
downloadspark-3e3c3d58d8d42b42e930d42eb70b0e84d02967eb.tar.gz
spark-3e3c3d58d8d42b42e930d42eb70b0e84d02967eb.tar.bz2
spark-3e3c3d58d8d42b42e930d42eb70b0e84d02967eb.zip
[SPARK-13706][ML] Add Python Example for Train Validation Split
## What changes were proposed in this pull request? This pull request adds a python example for train validation split. ## How was this patch tested? This was style tested through lint-python, generally tested with ./dev/run-tests, and run in notebook and shell environments. It was viewed in docs locally with jekyll serve. This contribution is my original work and I license it to Spark under its open source license. Author: JeremyNixon <jnixon2@gmail.com> Closes #11547 from JeremyNixon/tvs_example.
Diffstat (limited to 'examples/src')
-rw-r--r--examples/src/main/python/ml/train_validation_split.py68
1 files changed, 68 insertions, 0 deletions
diff --git a/examples/src/main/python/ml/train_validation_split.py b/examples/src/main/python/ml/train_validation_split.py
new file mode 100644
index 0000000000..161a200c61
--- /dev/null
+++ b/examples/src/main/python/ml/train_validation_split.py
@@ -0,0 +1,68 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import SparkContext
+# $example on$
+from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.regression import LinearRegression
+from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
+from pyspark.sql import SQLContext
+# $example off$
+
+"""
+This example demonstrates applying TrainValidationSplit to split data
+and perform model selection.
+Run with:
+
+ bin/spark-submit examples/src/main/python/ml/train_validation_split.py
+"""
+
+if __name__ == "__main__":
+ sc = SparkContext(appName="TrainValidationSplit")
+ sqlContext = SQLContext(sc)
+ # $example on$
+ # Load the libsvm-format sample data and randomly split it 70/30 into train and test sets.
+ data = sqlContext.read.format("libsvm")\
+ .load("data/mllib/sample_linear_regression_data.txt")
+ train, test = data.randomSplit([0.7, 0.3])
+ lr = LinearRegression(maxIter=10, regParam=0.1)
+
+ # We use a ParamGridBuilder to construct a grid of parameters to search over.
+ # TrainValidationSplit will try all combinations of values and determine the best model using
+ # the evaluator.
+ paramGrid = ParamGridBuilder()\
+ .addGrid(lr.regParam, [0.1, 0.01]) \
+ .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])\
+ .build()
+
+ # In this case the estimator is simply the linear regression.
+ # A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
+ tvs = TrainValidationSplit(estimator=lr,
+ estimatorParamMaps=paramGrid,
+ evaluator=RegressionEvaluator(),
+ # 80% of the data will be used for training, 20% for validation.
+ trainRatio=0.8)
+
+ # Run TrainValidationSplit, and choose the best set of parameters.
+ model = tvs.fit(train)
+ # Make predictions on test data. `model` is the model with the combination of parameters
+ # that performed best.
+ prediction = model.transform(test)
+ for row in prediction.take(5):
+ print(row)
+ # $example off$
+ sc.stop()