Using UDFs in Spark DataFrames

It is very convenient to create, register, and use user-defined functions (UDFs) in Spark. Recent releases of Apache Spark (1.5 and later) also support user-defined aggregation functions (UDAFs); a minimal sketch appears at the end of this post. Below is a short piece of code demonstrating how to create and use a UDF in the spark-shell. Quite often, data engineers massage the data and build the necessary toolkit functions, and data analysts or business analysts then consume them through Hive on Spark, Spark SQL, or other BI tools.

//Create a case class structure to hold the data for the demo.
//You can also directly convert a Hive table into a DataFrame,
//such as sqlContext.table("hive_table_name")
case class StockPrice(symbol: String, price: Seq[Double])

//create the demo data
val data = sc.parallelize(
  Seq(
    StockPrice("AAPL", Seq(93.5, 95.6, 102.7)),
    StockPrice("GOOG", Seq(604.5, 603.7, 614.1)),
    StockPrice("BABA", Seq(64.8, 95.2, 96.0))
  )
)

//Convert the data to a DataFrame
val df = data.toDF("symbol", "price")

//Register the result as a temp table.
//Alternatively, you can materialize the data in Hive by
//df.saveAsTable("stock_price")
df.registerTempTable("stock_price")

//Create a function that computes the average stock price
def avgStockPrice(price: Double*) = {
  var totalPrice = 0.0
  price.foreach(x => totalPrice += x)
  totalPrice / price.size
}

//Register the function with the SQLContext's UDF registry.
//The trailing underscore turns the method into a partially
//applied function value that can be passed to register.
sqlContext.udf.register("avgStockPrice", avgStockPrice _)

//Use the UDF in a query. Triple quotes let the SQL string
//span multiple lines.
sqlContext.sql(
  """select symbol, avgStockPrice(price) avg_stock_price
     from stock_price""").collect().foreach(println)
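
If you prefer to stay in the DataFrame API, the same logic can be wrapped as a column expression with org.apache.spark.sql.functions.udf instead of being registered by name. Below is a minimal sketch assuming Spark 1.3 or later; the name avgPriceUdf is just illustrative.

import org.apache.spark.sql.functions.udf

//Wrap the same averaging logic as a column expression.
//Seq[Double] matches the array<double> type of the price column.
val avgPriceUdf = udf { (prices: Seq[Double]) => prices.sum / prices.size }

df.select(df("symbol"), avgPriceUdf(df("price")).as("avg_stock_price"))
  .collect().foreach(println)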

Either way, the result is as follows.

[AAPL,97.26666666666667]
[GOOG,607.4333333333334]
[BABA,85.33333333333333]
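
As for the UDAF support mentioned above: since Spark 1.5 you can extend org.apache.spark.sql.expressions.UserDefinedAggregateFunction to aggregate values across rows rather than within one row. The sketch below is a minimal, hypothetical example (the class name AvgUDAF is my own) that averages a double column; it is meant to show the shape of the API, not production code.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

//A minimal UDAF that averages a double column across rows.
class AvgUDAF extends UserDefinedAggregateFunction {
  //The input is a single double column.
  def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
  //The intermediate buffer keeps a running sum and count.
  def bufferSchema: StructType = StructType(
    StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
  def dataType: DataType = DoubleType
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
    buffer(1) = 0L
  }

  //Fold one input row into the buffer, skipping nulls.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  //Combine two partial buffers (e.g. from different partitions).
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  def evaluate(buffer: Row): Double =
    if (buffer.getLong(1) == 0L) 0.0 else buffer.getDouble(0) / buffer.getLong(1)
}

Once defined, a UDAF is registered and used just like a UDF, e.g. sqlContext.udf.register("my_avg", new AvgUDAF), and can then be called in group-by queries.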