diff options
author | Davies Liu <davies@databricks.com> | 2015-02-04 15:55:09 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-02-04 15:55:09 -0800 |
commit | dc101b0e4e23dffddbc2f70d14a19fae5d87a328 (patch) | |
tree | e436271c351a64caa4727661cd6143ba6e415fa6 /sql | |
parent | e0490e271d078aa55d7c7583e2ba80337ed1b0c4 (diff) | |
download | spark-dc101b0e4e23dffddbc2f70d14a19fae5d87a328.tar.gz spark-dc101b0e4e23dffddbc2f70d14a19fae5d87a328.tar.bz2 spark-dc101b0e4e23dffddbc2f70d14a19fae5d87a328.zip |
[SPARK-5577] Python udf for DataFrame
Author: Davies Liu <davies@databricks.com>
Closes #4351 from davies/python_udf and squashes the following commits:
d250692 [Davies Liu] fix conflict
34234d4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
440f769 [Davies Liu] address comments
f0a3121 [Davies Liu] track life cycle of broadcast
f99b2e1 [Davies Liu] address comments
462b334 [Davies Liu] Merge branch 'master' of github.com:apache/spark into python_udf
7bccc3b [Davies Liu] python udf
58dee20 [Davies Liu] clean up
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 19 | ||||
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala | 27 |
2 files changed, 44 insertions, 2 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index ddce77deb8..4c2aeadae9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -128,7 +128,6 @@ trait Column extends DataFrame { */ def unary_! : Column = exprToColumn(Not(expr)) - /** * Equality test. * {{{ @@ -166,7 +165,7 @@ trait Column extends DataFrame { * * // Java: * import static org.apache.spark.sql.Dsl.*; - * df.filter( not(col("colA").equalTo(col("colB"))) ); + * df.filter( col("colA").notEqual(col("colB")) ); * }}} */ def !== (other: Any): Column = constructColumn(other) { o => @@ -174,6 +173,22 @@ trait Column extends DataFrame { } /** + * Inequality test. + * {{{ + * // Scala: + * df.select( df("colA") !== df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * + * // Java: + * import static org.apache.spark.sql.Dsl.*; + * df.filter( col("colA").notEqual(col("colB")) ); + * }}} + */ + def notEqual(other: Any): Column = constructColumn(other) { o => + Not(EqualTo(expr, o.expr)) + } + + /** * Greater than. * {{{ * // Scala: The following selects people older than 21. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 8d7c2a1b83..c60d407094 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -17,7 +17,13 @@ package org.apache.spark.sql +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.Accumulator +import org.apache.spark.api.python.PythonBroadcast +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType /** @@ -37,3 +43,24 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) { Column(ScalaUdf(f, dataType, exprs.map(_.expr))) } } + +/** + * A user-defined Python function. To create one, use the `pythonUDF` functions in [[Dsl]]. + * This is used by Python API. + */ +private[sql] case class UserDefinedPythonFunction( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType) { + + def apply(exprs: Column*): Column = { + val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars, + accumulator, dataType, exprs.map(_.expr)) + Column(udf) + } +} |