Fork me on GitHub

sparksql详解

Spark SQL的Dataset基本操作

https://www.toutiao.com/i6631318012546793992/
版本:Spark 2.3.1,hadoop-2.7.4,jdk1.8为例讲解

基本简介和准备工作

Dataset接口是在spark 1.6引入的,受益于RDD(强类型,可以使用强大的lambda函数),同时也可以享受Spark SQL优化执行引擎的优点。Dataset的可以从jvm 对象创建,然后就可以使用转换函数(map,flatmap,filter等)。

1.6版本之前常用的Dataframe是一种特殊的Dataset,也即是

1
type DataFrame = Dataset[Row]

结构化数据文件(json,csv,orc等),hive表,外部数据库,已有RDD。
下面进行测试,大家可能都会说缺少数据,实际上Spark源码里跟我们提供了丰富的测试数据。源码的examples路径下:examples/src/main/resources。

首先要创建一个SparkSession:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
val sparkConf = new SparkConf()
.setAppName(this.getClass.getName)
.setMaster("local[*]")
.set("yarn.resourcemanager.hostname", "localhost")
// executor的实例数
.set("spark.executor.instances","2")
.set("spark.default.parallelism","4")
// sql shuffle的并行度,由于是本地测试,所以设置较小值,避免产生过多空task,实际上要根据生产数据量进行设置。
.set("spark.sql.shuffle.partitions","4")
.setJars(List("/Users/meitu/Desktop/sparkjar/bigdata.jar"
,"/opt/jars/spark-streaming-kafka-0-10_2.11-2.3.1.jar"
,"/opt/jars/kafka-clients-0.10.2.2.jar"
,"/opt/jars/kafka_2.11-0.10.2.2.jar"
))

val spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate()

创建dataset

1
2
3
4
5
6
val sales = spark.createDataFrame(
Seq( ("Warsaw", 2016, 100), ("Warsaw", 2017, 200), ("Warsaw", 2015, 100), ("Warsaw", 2017, 200), ("Beijing", 2017, 200), ("Beijing", 2016, 200),
("Beijing", 2015, 200), ("Beijing", 2014, 200), ("Warsaw", 2014, 200),
("Boston", 2017, 50), ("Boston", 2016, 50), ("Boston", 2015, 50),
("Boston", 2014, 150)))
.toDF("city", "year", "amount")

使用函数的时候要导入包:

1
import org.apache.spark.sql.functions.{col,expr}

select

列名称可以是字符串,这种形式无法对列名称使用表达式进行逻辑操作。
使用col函数,可以直接对列进行一些逻辑操作。

1
2
sales.select("city","year","amount").show(1)
sales.select(col("city"),col("amount")+1).show(1)

selectExpr

参数是字符串,且直接可以使用表达式。
也可以使用select+expr函数来替代。

1
2
sales.selectExpr("city","year as date","amount+1").show(10)
sales.select(expr("city"),expr("year as date"),expr("amount+1")).show(10)

filter

参数可以是与col结合的表达式,参数类型为row返回值为boolean的函数,字符串表达式。

1
2
3
sales.filter(col("amount")>150).show()
sales.filter(row=>{ row.getInt(2)>150}).show(10)
sales.filter("amount > 150 ").show(10)

where

类似于fliter,参数可以是与col函数结合的表达式也可以是直接使用表达式字符串。

1
2
sales.where(col("amount")>150).show()
sales.where("amount > 150 ").show()

group by

主要是以count和agg聚合函数为例讲解groupby函数。

1
2
sales.groupBy("city").count().show(10)
sales.groupBy(col("city")).agg(sum("amount").as("total")).show(10)

union

两个dataset的union操作这个等价于union all操作,所以要实现传统数据库的union操作,需要在其后使用distinct进行去重操作。

1
sales.union(sales).groupBy("city").count().show()

join

join操作相对比较复杂,具体如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
// 相同的列进行join
sales.join(sales,"city").show(10)
// 多列join
sales.join(sales,Seq("city","year")).show()
/* 指定join类型, join 类型可以选择:
`inner`, `cross`, `outer`, `full`, `full_outer`,
`left`, `left_outer`, `right`, `right_outer`,
`left_semi`, `left_anti`.
*/
// 内部join
sales.join(sales,Seq("city","year"),"inner").show()
/* join条件 :
可以在join方法里放入join条件,
也可以使用where,这两种情况都要求字段名称不一样。
*/
sales.join(sales, col("city").alias("city1") === col("city")).show() sales.join(sales).where(col("city").alias("city1") === col("city")).show()
/*
dataset的self join 此处使用where作为条件,
需要增加配置.set("spark.sql.crossJoin.enabled","true")
也可以加第三个参数,join类型,可以选择如下:
`inner`, `cross`, `outer`, `full`, `full_outer`, `left`,
`left_outer`, `right`, `right_outer`, `left_semi`, `left_anti`
*/
sales.join(sales,sales("city") === sales("city")).show() sales.join(sales).where(sales("city") === sales("city")).show()
/* joinwith,可以指定第三个参数,join类型,
类型可以选择如下:
`inner`, `cross`, `outer`, `full`, `full_outer`, `left`, `left_outer`,
`right`, `right_outer`。
*/
sales.joinWith(sales,sales("city") === sales("city"),"inner").show()

输出结果:

order by

orderby 全局有序,其实用的还是sort

1
2
sales.orderBy(col("year").desc,col("amount").asc).show()
sales.orderBy("city","year").show()

sort

全局排序,直接替换掉8小结的orderby即可。

sortwithinpartition

在分区内部进行排序,局部排序。

1
2
sales.sortWithinPartitions(col("year").desc,col("amount").asc).show()
sales.sortWithinPartitions("city","year").show()

可以看到,city为背景的应该是分配到不同的分区,然后每个分区内部year都是有序的。

withColumn

1
2
3
4
5
6
7
8
9
10
11
/* withColumn 
假如列,存在就替换,不存在新增
withColumnRenamed
对已有的列进行重命名
*/
//相当于给原来amount列,+1
sales.withColumn("amount",col("amount")+1).show()
// 对amount列+1,然后将值增加到一个新列 amount1
sales.withColumn("amount1",col("amount")+1).show()
// 将amount列名,修改为amount1
sales.withColumnRenamed("amount","amount1").show()

foreach

这个跟rdd的foreach一样,元素类型是row。

1
2
3
sales.foreach(row=>{ 
println(row.getString(0))
})

foreachPartition

跟RDD的foreachPartition一样,针对分区进行计算,对于输出到数据库,kafka等数据相对于使用foreach可以大量减少连接数。

1
2
3
4
5
6
7
8
sales.foreachPartition(partition=>{ 
//打开数据库链接等
partition.foreach(each=>{
println(each.getString(0))
//插入数据库
})
//关闭数据库链接
})

distinct

针对dataset的行去重,返回的是所有行都不重复的dataset。

1
sales.distinct().show(10)

dropDuplicates

这个适用于dataset有唯一的主键,然后对主键进行去重。

1
2
3
4
val before = sales.count()
val after = sales.dropDuplicates("city").count()
println("before ====> " +before)
println("after ====> "+after)

drop

删除一列,或者多列,这是一个变参数算子。

1
2
3
4
5
sales.drop("city").show()
打印出来schema信息如下:
root
|-- year: integer (nullable = false)
|-- amount: integer (nullable = false)

printSchema

输出dataset的schema信息

1
2
3
4
5
6
7
8
9
10
11
sales.printSchema()

输出结果如下:

root

|-- city: string (nullable = true)

|-- year: integer (nullable = false)

|-- amount: integer (nullable = false)

explain()

打印执行计划,这个便于调试,了解spark sql引擎的优化执行的整个过程

1
2
3
4
5
6
sales.orderBy(col("year").desc,col("amount").asc).explain()
执行计划输出如下:
== Physical Plan ==
*(1) Sort [year#7 DESC NULLS LAST, amount#8 ASC NULLS FIRST], true, 0
+- Exchange rangepartitioning(year#7 DESC NULLS LAST, amount#8 ASC NULLS FIRST, 3)
+- LocalTableScan [city#6, year#7, amount#8]

Spark SQL入门到精通之第二篇Dataset的复杂操作

本文是Spark SQL入门到精通系列第二弹,数据仓库常用的操作:
cube,rollup,pivot操作。

cube

简单的理解就是维度及度量组成的数据体。

1
2
3
4
sales.cube("city","year") 
.agg(sum("amount"))
.sort(col("city").desc_nulls_first,col("year").desc_nulls_first)
.show()

举个简单的例子,上面的纬度city,year两个列是纬度,然后amount是要进行聚合的度量。
实际上就相当于,(year,city),(year),(city),() 分别分组然后对amount求sum,最终输出结果,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
val city_year = sales
.groupBy("city","year").
agg(sum("amount"))
val city = sales.groupBy("city")
.agg(sum("amount") as "amount")
.select(col("city"), lit(null) as "year", col("amount"))
val year = sales.groupBy("year")
.agg(sum("amount") as "amount")
.select( lit(null) as "city",col("year"), col("amount"))
val none = sales .groupBy()
.agg(sum("amount") as "amount")
.select(lit(null) as "city", lit(null) as "year", col("amount"))
city_year.union(city).union(year).union(none)
.sort(desc_nulls_first("city"), desc_nulls_first("year"))
.show()

rollup

这里也是以案例开始,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
val expenses = spark.createDataFrame(Seq( 
((2012, Month.DECEMBER, 12), 5),
((2016, Month.AUGUST, 13), 10),
((2017, Month.MAY, 27), 15))
.map { case ((yy, mm, dd), a) => (LocalDate.of(yy, mm, dd), a) }
.map { case (d, a) => (d.toString, a) }
.map { case (d, a) => (Date.valueOf(d), a) }).toDF("date", "amount")
// rollup time!
val res = expenses
.rollup(year(col("date")) as "year", month(col("date")) as "month")
.agg(sum("amount") as "amount")
.sort(col("year").asc_nulls_last, col("month").asc_nulls_last)
.show()

这个等价于分别对(year,month),(year),()进行 groupby 对amount求sum,然后再进行union操作,也即是可以按照下面的实现:

1
2
3
4
5
6
7
8
9
10
11
12
val year_month = expenses
.groupBy(year(col("date")) as "year", month(col("date")) as "month")
.agg(sum("amount") as "amount")
val yearOnly = expenses.groupBy(year(col("date")) as "year")
.agg(sum("amount") as "amount")
.select(col("year"), lit(null) as "month", col("amount"))
val none = expenses.groupBy()
.agg(sum("amount") as "amount")
.select(lit(null) as "year", lit(null) as "month", col("amount"))
year_month.union(yearOnly).union(none)
.sort(col("year").asc_nulls_last, col("month").asc_nulls_last)
.show()

pivot

旋转操作
https://mp.weixin.qq.com/s/Jky2q3FW5wqQ-prpsW2OEw
Pivot 算子是 spark 1.6 版本开始引入的,在 spark2.4版本中功能做了增强,还是比较强大的,做过数据清洗ETL工作的都知道,行列转换是一个常见的数据整理需求。spark 中的Pivot 可以根据枢轴点(Pivot Point) 把多行的值归并到一行数据的不同列,这个估计不太好理解,我们下面使用例子说明,看看pivot 这个算子在处理复杂数据时候的威力。

1
2
3
4
5
6
7
8
9
10
11
12
13
val sales = spark.createDataFrame(Seq( 
("Warsaw", 2016, 100,"Warsaw"),
("Warsaw", 2017, 200,"Warsaw"),
("Warsaw", 2016, 100,"Warsaw"),
("Warsaw", 2017, 200,"Warsaw"),
("Boston", 2015, 50,"Boston"),
("Boston", 2016, 150,"Boston"),
("Toronto", 2017, 50,"Toronto")))
.toDF("city", "year", "amount","test")
sales.groupBy("year")
.pivot("city",Seq("Warsaw","Boston","Toronto"))
.agg(sum("amount") as "amount")
.show()

思路就是首先对year分组,然后旋转city字段,只取:
“Warsaw”,”Boston”,”Toronto”
然后对amount进行聚合操作
案例:
使用Pivot 来统计天气走势。
下面是西雅图的天气数据表,每行代表一天的天气最高值:

1
2
3
4
5
6
7
8
9
10
  Date         TempF)   
07-22-2018 86
07-23-2018 90
07-24-2018 91
07-25-2018 92
07-26-2018 92
07-27-2018 88
07-28-2018 85
07-29-2018 94
07-30-2018 89

如果我们想看下最近几年的天气走势,如果这样一天一行数据,是很难看出趋势来的,最直观的方式是 按照年来分行,然后每一列代表一个月的平均天气,这样一行数据,就可以看到这一年12个月的一个天气走势,下面我们使用 pivot 来构造这样一个查询结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
SELECT 
*
FROM
(
SELECT
year(date) year,
month(date) month,
temp
FROM
high_temps
WHERE
date between DATE '2015-01-01'
and DATE '2018-08-31'
) PIVOT (
CAST (
AVG(temp) as DECIMAL(4, 1)
) FOR month in (
1 JAN, 2 FEB, 3 MAR, 4 APR, 5 MAY, 6 JUN, 7 JUL,
8 AUG, 9 SEP, 10 OCT, 11 NOV, 12 DEC
)
)
ORDER BY
year DESC;

结果如下图:

是不是很直观,第一行就代表 2018 年,从1月到12月的平均天气,能看出一年的天气走势。
我们来看下这个 sql 是怎么玩的,首先是一个 子查询语句,我们最终关心的是 年份,月份,和最高天气值,所以先使用子查询对原始数据进行处理,从日期里面抽取出来年份和月份。

下面在子查询的结果上使用 pivot 语句,pivot 第一个参数是一个聚合语句,这个代表聚合出来一个月30天的一个平均气温,第二个参数是 FOR month,这个是指定以哪个列为枢轴列,第三个 In 子语句指定我们需要进行行列转换的具体的 枢轴点(Pivot Point)的值,上面的例子中 1到12月份都包含了,而且给了一个别名,如果只指定 1到6月份,结果就如下了:

上面sql语句里面有个特别的点需要注意, 就是聚合的时候有个隐含的维度字段,就是 年份,按理来讲,我们没有写 group-by year, 为啥结果表里面不同年份区分在了不同的行,原因是,FORM 子查询出来的每行有 3个列, year,month,tmp,如果一个列既不出现在进行聚合计算的列中(temp 是聚合计算的列), 也不作为枢轴列 , 就会作为聚合的时候一个隐含的维度。我们的例子中算平均值聚合操作的维度是 (year, month),一个是隐含维度,一个是 枢轴列维度, 这一点一定要注意,如果不需要按照 year 来区分,FORM 查询的时候就不要加上这个列。
指定多个聚合语句
上文中只有一个聚合语句,就是计算平均天气,其实是可以加多个聚合语句的,比如我们需要看到 7,8,9 月份每个月的最大气温和平均气温,就可以用以下SQL语句。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
SELECT 
*
FROM
(
SELECT
year(date) year,
month(date) month,
temp
FROM
high_temps
WHERE
date between DATE '2015-01-01'
and DATE '2018-08-31'
) PIVOT (
CAST (
AVG(temp) as DECIMAL(4, 1)
) avg,
max(temp) max FOR month in (6 JUN, 7 JUL, 8 AUG, 9 SEP)
)
ORDER BY
year DESC;

上文中指定了两个聚合语句,查询后, 枢轴点(Pivot Point) 和 聚合语句 的笛卡尔积作为结果的不同列,也就是 _, 看下图:

聚合列(Grouping Columns)和 枢轴列(Pivot Columns)的不同之处
现在假如我们有西雅图每天的最低温数据,我们需要把最高温和最低温放在同一张表里面看对比着看.

1
2
3
4
5
6
7
8
9
Date	   TempF)
… …
08-01-2018 59
08-02-2018 58
08-03-2018 59
08-04-2018 58
08-05-2018 59
08-06-2018 59
… …

我们使用 UNION ALL 把两张表做一个合并:

1
2
3
4
5
6
7
SELECT date, temp, 'H' as flag 
FROM
high_temps
UNION ALL
SELECT date, temp, 'L' as flag
FROM
low_temps;

现在使用 pivot 来进行处理:

1
2
3
4
5
6
7
8
9
10
11
12
SELECT * FROM (
SELECT date,temp,'H' as flag
FROM high_temps
UNION ALL
SELECT date,temp,'L' as flag
FROM low_temps
)
WHERE date BETWEEN DATE '2015-01-01' AND DATE '2018-08-31'
PIVOT(
CAST(avg(temp) as DECIMAL(4,1))
FOR month in (6 JUN,7 JUL,8 AUG , 9 SEP)
) ORDER BY year DESC ,`H/L` ASC;

我们统计了 4年中 7,8,9 月份最低温和最高温的平均值,这里要注意的是,我们把 year 和 一个最低最高的标记(H/L)都作为隐含维度,不然算出来的就是最低最高温度在一起的平均值了。结果如下图:

上面的查询中,我们把最低最高的标记(H/L)也作为了一个隐含维度,group-by 的维度就变成了 (year, H/L, month), 但是year 和 H/L 体现在了不同行上面,month 体现在了不同的列上面,这就是 聚合列(Grouping Columns)和 枢轴列(Pivot Columns)的不同之处。

再接再厉,我们把 H/L 作为 枢轴列(Pivot Columns) 进行查询:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
SELECT * FROM (
SELECT date,temp,'H' as flag
FROM high_temps
UNION ALL
SELECT date,temp,'L' as flag
FROM low_temps
)
WHERE date BETWEEN DATE '2015-01-01' AND DATE '2018-08-31'
PIVOT(
CAST(avg(temp) as DECIMAL(4,1))
FOR (month,flag) in (
(6,'H') JUN_hi,(6,'L') JUN_lo,
(7,'H') JUl_hi,(7,'L') JUl_lo,
(8,'H') AUG_hi,(8,'L') AUG_lo,
(9,'H') SEP_hi,(9,'L') SEP_lo
)
) ORDER BY year DESC;

结果的展现方式就和上面的不同了,虽然每个单元格值都是相同的,但是把 H/L 和 month 笛卡尔积作为列,H/L 维度体现在了列上面了:

源码:Spark SQL 分区特性第一弹

常见RDD分区

Spark Core 中的RDD的分区特性大家估计都很了解,这里说的分区特性是指从数据源读取数据的第一个RDD或者Dataset的分区,而后续再介绍转换过程中分区的变化。
举几个浪尖在星球里分享比较多的例子,比如:
Spark Streaming 与kafka 结合 DirectDstream 生成的微批RDD(kafkardd)分区数和kafka分区数一样。
Spark Streaming 与kafka结合 基于receiver的方式,生成的微批RDD(blockRDD),分区数就是block数。
普通的文件RDD,那么分可分割和不可分割,通常不可分割的分区数就是文件数。可分割需要计算而且是有条件的,在星球里分享过了。
这些都很简单,那么今天咱们要谈的是Spark DataSet的分区数的决定因素。

准备数据

首先是由Seq数据集合生成一个Dataset

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
val sales = spark.createDataFrame(Seq(
("Warsaw", 2016, 110),
("Warsaw", 2017, 10),
("Warsaw", 2015, 100),
("Warsaw", 2015, 50),
("Warsaw", 2015, 80),
("Warsaw", 2015, 100),
("Warsaw", 2015, 130),
("Warsaw", 2015, 160),
("Warsaw", 2017, 200),
("Beijing", 2017, 100),
("Beijing", 2016, 150),
("Beijing", 2015, 50),
("Beijing", 2015, 30),
("Beijing", 2015, 10),
("Beijing", 2014, 200),
("Beijing", 2014, 170),
("Boston", 2017, 50),
("Boston", 2017, 70),
("Boston", 2017, 110),
("Boston", 2017, 150),
("Boston", 2017, 180),
("Boston", 2016, 30),
("Boston", 2015, 200),
("Boston", 2014, 20)
)).toDF("city", "year", "amount")

将Dataset存处为partquet格式的hive表,分两种情况:
用city和year字段分区

1
sales.write.partitionBy("city","year").mode(SaveMode.Overwrite).saveAsTable("ParquetTestCityAndYear")

用city字段分区

1
sales.write.partitionBy("city").mode(SaveMode.Overwrite).saveAsTable("ParquetTestCity")

读取数据采用的是

1
val res = spark.read.parquet("/user/hive/warehouse/parquettestcity")

直接展示,结果发现结果会随着spark.default.parallelism变化而变化。文章里只读取city字段分区的数据,特点就是只有单个分区字段。

1. spark.default.parallelism =40

Dataset的分区数是由参数:
目录数和生成的FileScanRDD的分区数分别数下面截图的第一行和第二行
这个分区数目正好是文件数,那么假如不了解细节的话,肯定会认为分区数就是由文件数决定的,其实不然。

2. spark.default.parallelism =4

Dataset的分区数是由参数:

1
println("partition size = "+res.rdd.partitions.length)

目录数和生成的FileScanRDD的分区数分别数下面截图的第一行和第二行。

那么数据源生成的Dataset的分区数到底是如何决定的呢?
我们这种情况,我只能告诉你是由下面的函数在生成FileScanRDD的时候计算得到的,具体计算细节可以仔细阅读该函数。该函数是类FileSourceScanExec的方法。

那么数据源生成的Dataset的分区数到底是如何决定的呢?
我们这种情况,我只能告诉你是由下面的函数在生成FileScanRDD的时候计算得到的,具体计算细节可以仔细阅读该函数。该函数是类FileSourceScanExec的方法。
那么数据源生成的Dataset的分区数到底是如何决定的呢?
我们这种情况,我只能告诉你是由下面的函数在生成FileScanRDD的时候计算得到的,具体计算细节可以仔细阅读该函数。该函数是类FileSourceScanExec的方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
private def createNonBucketedReadRDD(
readFile: (PartitionedFile) => Iterator[InternalRow],
selectedPartitions: Seq[PartitionDirectory],
fsRelation: HadoopFsRelation): RDD[InternalRow] = {
/*
selectedPartitions 的大小代表目录数目
*/
println("selectedPartitions.size : "+ selectedPartitions.size)
val defaultMaxSplitBytes =
fsRelation.sparkSession.sessionState.conf.filesMaxPartitionBytes
val openCostInBytes = fsRelation.sparkSession.sessionState.conf.filesOpenCostInBytes

// spark.default.parallelism
val defaultParallelism = fsRelation.sparkSession.sparkContext.defaultParallelism

// 计算文件总大小,单位字节数
val totalBytes = selectedPartitions.flatMap(_.files.map(_.getLen + openCostInBytes)).sum

//计算平均每个并行度读取数据大小
val bytesPerCore = totalBytes / defaultParallelism

// 首先spark.sql.files.openCostInBytes 该参数配置的值和bytesPerCore 取最大值
// 然后,比较spark.sql.files.maxPartitionBytes 取小者
val maxSplitBytes = Math.min(defaultMaxSplitBytes, Math.max(openCostInBytes, bytesPerCore))
logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " +
s"open cost is considered as scanning $openCostInBytes bytes.")

// 这对目录遍历
val splitFiles = selectedPartitions.flatMap { partition =>
partition.files.flatMap { file =>
val blockLocations = getBlockLocations(file)

//判断文件类型是否支持分割,以parquet为例,是支持分割的
if (fsRelation.fileFormat.isSplitable(
fsRelation.sparkSession, fsRelation.options, file.getPath)) {

// eg. 0 until 2不包括 2。相当于
// println(0 until(10) by 3) 输出 Range(0, 3, 6, 9)
(0L until file.getLen by maxSplitBytes).map { offset =>

// 计算文件剩余的量
val remaining = file.getLen - offset

// 假如剩余量不足 maxSplitBytes 那么就剩余的作为一个分区
val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining

// 位置信息
val hosts = getBlockHosts(blockLocations, offset, size)
PartitionedFile(
partition.values, file.getPath.toUri.toString, offset, size, hosts)
}
} else {
// 不可分割的话,那即是一个文件一个分区
val hosts = getBlockHosts(blockLocations, 0, file.getLen)
Seq(PartitionedFile(
partition.values, file.getPath.toUri.toString, 0, file.getLen, hosts))
}
}
}.toArray.sortBy(_.length)(implicitly[Ordering[Long]].reverse)

val partitions = new ArrayBuffer[FilePartition]
val currentFiles = new ArrayBuffer[PartitionedFile]
var currentSize = 0L

/** Close the current partition and move to the next. */
def closePartition(): Unit = {
if (currentFiles.nonEmpty) {
val newPartition =
FilePartition(
partitions.size,
currentFiles.toArray.toSeq) // Copy to a new Array.
partitions += newPartition
}
currentFiles.clear()
currentSize = 0
}

// Assign files to partitions using "Next Fit Decreasing"
splitFiles.foreach { file =>
if (currentSize + file.length > maxSplitBytes) {
closePartition()
}
// Add the given file to the current partition.
currentSize += file.length + openCostInBytes
currentFiles += file
}
closePartition()

println("FileScanRDD partitions size : "+partitions.size)
new FileScanRDD(fsRelation.sparkSession, readFile, partitions)
}

找到一列中的中位数

https://spark.apache.org/docs/latest/api/sql/index.html
df函数: approxQuantile
sql函数: percentile_approx

自定义数据源

ServiceLoader
https://mp.weixin.qq.com/s/QpYvqJpw7TnFAY8rD6Jf3w
ServiceLoader是SPI的是一种实现,所谓SPI,即Service Provider Interface,用于一些服务提供给第三方实现或者扩展,可以增强框架的扩展或者替换一些组件。
要配置在相关项目的固定目录下:
resources/META-INF/services/接口全称。
这个在大数据的应用中颇为广泛,比如Spark2.3.1 的集群管理器插入:
SparkContext类

1
2
3
4
5
6
7
8
9
10
11
  private def getClusterManager(url: String): Option[ExternalClusterManager] = {
val loader = Utils.getContextOrSparkClassLoader
val serviceLoaders =
ServiceLoader.load(classOf[ExternalClusterManager], loader).asScala.filter(_.canCreate(url))
if (serviceLoaders.size > 1) {
throw new SparkException(
s"Multiple external cluster managers registered for the url $url: $serviceLoaders")
}
serviceLoaders.headOption
}
}

配置是在

spark sql数据源的接入,新增数据源插入的时候可以采用这种方式,要实现的接口是DataSourceRegister。
简单测试
首先实现一个接口

1
2
3
4
5
6
7
package bigdata.spark.services;

public interface DoSomething {
//可以制定实现类名加载
public String shortName();
public void doSomeThing();
}

然后将接口配置在resources/META-INF/services/
bigdata.spark.services.DoSomething文件
内容:
实现该接口

1
2
3
4
5
6
7
8
9
10
11
12
13
package bigdata.spark.services;

public class SayHello implements DoSomething {
@Override
public String shortName() {
return "SayHello";
}

@Override
public void doSomeThing() {
System.out.println("hello !!!");
}
}

测试

1
2
3
4
5
6
7
8
9
10
11
12
13
package bigdata.spark.services;
import java.util.ServiceLoader;
public class test {
static ServiceLoader<DoSomething> loader = ServiceLoader.load(DoSomething.class);
public static void main(String[] args){
for(DoSomething sayhello : loader){
//要加载的类名称我们可以制定
if(sayhello.shortName().equalsIgnoreCase("SayHello")){
sayhello.doSomeThing();
}
}
}
}

这个主要是为讲自定义数据源作准备。

https://articles.zsxq.com/id_702s32f46zet.html
首先要搞明白spark是如何支持多数据源的,昨天说了是通过serverloader加载的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String): Class[_] = {
val provider1 = backwardCompatibilityMap.getOrElse(provider, provider)
// 指定路径加载,默认加载类名是DefaultSource
val provider2 = s"$provider1.DefaultSource"
val loader = Utils.getContextOrSparkClassLoader
// ServiceLoader加载
val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)

try {
serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider1)).toList match {
// the provider format did not match any given registered aliases
case Nil =>
try {
Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))) match {
case Success(dataSource) =>
// Found the data source using fully qualified path
dataSource
case Failure(error) =>
if (provider1.toLowerCase == "orc" ||
provider1.startsWith("org.apache.spark.sql.hive.orc")) {
throw new AnalysisException(
"The ORC data source must be used with Hive support enabled")
} else if (provider1.toLowerCase == "avro" ||
provider1 == "com.databricks.spark.avro") {
throw new AnalysisException(
s"Failed to find data source: ${provider1.toLowerCase}. Please find an Avro " +
"package at http://spark.apache.org/third-party-projects.html")
} else {
throw new ClassNotFoundException(
s"Failed to find data source: $provider1. Please find packages at " +
"http://spark.apache.org/third-party-projects.html",
error)
}
}
} catch {
case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal
// NoClassDefFoundError's class name uses "/" rather than "." for packages
val className = e.getMessage.replaceAll("/", ".")
if (spark2RemovedClasses.contains(className)) {
throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " +
"Please check if your library is compatible with Spark 2.0", e)
} else {
throw e
}
}
case head :: Nil =>
// there is exactly one registered alias
head.getClass
case sources =>
// There are multiple registered aliases for the input
sys.error(s"Multiple sources found for $provider1 " +
s"(${sources.map(_.getClass.getName).mkString(", ")}), " +
"please specify the fully qualified class name.")
}
} catch {
case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] =>
// NoClassDefFoundError's class name uses "/" rather than "." for packages
val className = e.getCause.getMessage.replaceAll("/", ".")
if (spark2RemovedClasses.contains(className)) {
throw new ClassNotFoundException(s"Detected an incompatible DataSourceRegister. " +
"Please remove the incompatible library from classpath or upgrade it. " +
s"Error: ${e.getMessage}", e)
} else {
throw e
}
}
}

其实,从这个点,你可以思考一下,自己能学到多少东西:类加载,可扩展的编程思路。

主要思路是:

  1. 实现DefaultSource。

  2. 实现工厂类。

  3. 实现具体的数据加载类。

首先,主要是有三个实现吧,需要反射加载的类DefaultSource,这个名字很固定的:

1
2
3
4
5
6
7
8
package bigdata.spark.SparkSQL.DataSources

import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}

class DefaultSource extends DataSourceV2 with ReadSupport {

def createReader(options: DataSourceOptions) = new SimpleDataSourceReader()
}

然后是,要实现DataSourceReader,负责创建阅读器工厂:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
package bigdata.spark.SparkSQL.DataSources

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class SimpleDataSourceReader extends DataSourceReader {

def readSchema() = StructType(Array(StructField("value", StringType)))

def createDataReaderFactories = {
val factoryList = new java.util.ArrayList[DataReaderFactory[Row]]
factoryList.add(new SimpleDataSourceReaderFactory())
factoryList
}
}

数据源的具体实现类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
package bigdata.spark.SparkSQL.DataSources

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}

class SimpleDataSourceReaderFactory extends
DataReaderFactory[Row] with DataReader[Row] {
def createDataReader = new SimpleDataSourceReaderFactory()
val values = Array("1", "2", "3", "4", "5")

var index = 0

def next = index < values.length

def get = {
val row = Row(values(index))
index = index + 1
row
}

def close() = Unit
}

使用我们默认的数据源:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
package bigdata.spark.SparkSQL.DataSources

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object App {

def main(args: Array[String]): Unit = {
val sparkConf = new SparkConf().setAppName(this.getClass.getName).setMaster("local[*]")
.set("yarn.resourcemanager.hostname", "mt-mdh.local")
.set("spark.executor.instances","2")
.setJars(List("/opt/sparkjar/bigdata.jar"
,"/opt/jars/spark-streaming-kafka-0-10_2.11-2.3.1.jar"
,"/opt/jars/kafka-clients-0.10.2.2.jar"
,"/opt/jars/kafka_2.11-0.10.2.2.jar"))
val spark = SparkSession
.builder()
.config(sparkConf)
.getOrCreate()

val simpleDf = spark.read
.format("bigdata.spark.SparkSQL.DataSources")
.load()

simpleDf.show()
spark.stop()
}
}

format里面指定的是包路径,然后加载的时候会加上默认类名:DefaultSource。

删除多个字段

1
2
3
def dropColumns(columns: Seq[String]):DataFAME={
columns.foldLeft(dataFame)((df,column)=> df.drop(column))
}

spark sql 窗口函数

https://mp.weixin.qq.com/s/A5CiLWPdg1nkjuNImetaAw
最近理了下spark sql 中窗口函数的知识,打算开几篇文章讲讲,窗口函数有很多应用场景,比如说炒股的时候有个5日移动均线,或者让你对一个公司所有销售所有部门按照销售业绩进行排名等等,都是要用到窗口函数,在spark sql中,窗口函数和聚合函数的区别,之前文章中也提到过,就是聚合函数是按照你聚合的维度,每个分组中算出来一个聚合值,而窗口函数是对每一行,都根据当前窗口(5日均线就是今日往前的5天组成的窗口),都聚合出一个值。

1 开胃菜,spark sql 窗口函数的基本概念和使用姿势
2 spark sql 中窗口函数深入理解
3 不是UDF,也不是UDAF,教你自定义一个窗口函数(UDWF)
4 从 spark sql 源码层面理解窗口函数

什么是简单移动平均值

简单移动平均(英语:Simple Moving Average,SMA)是某变数之前n个数值的未作加权算术平均。例如,收市价的10日简单移动平均指之前10日收市价的平均数。
例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
val spark = SparkSession
.builder()
.master("local")
.appName("DateFrameFromJsonScala")
.config("spark.some.config.option", "some-value")
.getOrCreate()

import spark.implicits._

val df = List(
("站点1", "201902025", 50),
("站点1", "201902026", 90),
("站点1", "201902026", 100),
("站点2", "201902027", 70),
("站点2", "201902028", 60),
("站点2", "201902029", 40))
.toDF("site", "date", "user_cnt")

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

/**
* 这个 window spec 中,数据根据用户(customer)来分去。
* 每一个用户数据根据时间排序。然后,窗口定义从 -1(前一行)到 1(后一行)
* ,每一个滑动的窗口总用有3行
*/
val wSpce = Window.partitionBy("site").orderBy("date").rowsBetween(-1, 1)

df.withColumn("movinAvg", avg("user_cnt").over(wSpce)).show()

spark.stop()

结果:
2~3~2

1
2
3
4
5
6
7
8
9
10
+----+---------+--------+------------------+
|site| date|user_cnt| movinAvg|
+----+---------+--------+------------------+
| 站点1|201902025| 50| 70.0|
| 站点1|201902026| 90| 80.0|
| 站点1|201902026| 100| 95.0|
| 站点2|201902027| 70| 65.0|
| 站点2|201902028| 60|56.666666666666664|
| 站点2|201902029| 40| 50.0|
+----+---------+--------+------------------+

窗口函数和窗口特征定义

正如上述例子中,窗口函数主要包含两个部分:

指定窗口特征(wSpec):

“partitionyBY” 定义数据如何分组;在上面的例子中,是用户 site

“orderBy” 定义分组中的排序

“rowsBetween” 定义窗口的大小

指定窗口函数函数:

指定窗口函数函数,你可以使用 org.apache.spark.sql.functions 的“聚合函数(Aggregate Functions)”和”窗口函数(Window Functions)“类别下的函数.

累计汇总

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

val spark = SparkSession
.builder()
.master("local")
.appName("DateFrameFromJsonScala")
.config("spark.some.config.option", "some-value")
.getOrCreate()

import spark.implicits._

val df = List(
("站点1", "201902025", 50),
("站点1", "201902026", 90),
("站点1", "201902026", 100),
("站点2", "201902027", 70),
("站点2", "201902028", 60),
("站点2", "201902029", 40))
.toDF("site", "date", "user_cnt")

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

/** .rowsBetween(Long.MinValue, 0) :窗口的大小是按照排序从最小值到当前行 */
val wSpce = Window.partitionBy("site").orderBy("date").rowsBetween(Long.MinValue, 1)

df.withColumn("cumsum", sum("user_cnt").over(wSpce)).show()


spark.stop()

结果

1
2
3
4
5
6
7
8
9
10
+----+---------+--------+------+
|site| date|user_cnt|cumsum|
+----+---------+--------+------+
| 站点1|201902025| 50| 140|
| 站点1|201902026| 90| 240|
| 站点1|201902026| 100| 240|
| 站点2|201902027| 70| 130|
| 站点2|201902028| 60| 170|
| 站点2|201902029| 40| 170|
+----+---------+--------+------+

前一行数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
val spark = SparkSession
.builder()
.master("local")
.appName("DateFrameFromJsonScala")
.config("spark.some.config.option", "some-value")
.getOrCreate()

import spark.implicits._

val df = List(
("站点1", "201902025", 50),
("站点1", "201902026", 90),
("站点1", "201902026", 100),
("站点2", "201902027", 70),
("站点2", "201902028", 60),
("站点2", "201902029", 40)).toDF("site", "date", "user_cnt")

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

/** .rowsBetween(Long.MinValue, 0) :窗口的大小是按照排序从最小值到当前行 */
val wSpce = Window.partitionBy("site").orderBy("date")

df.withColumn("preUserCnt", lag(df("user_cnt"),1).over(wSpce)).show()


spark.stop()

结果

1
2
3
4
5
6
7
8
9
10
+----+---------+--------+----------+
|site| date|user_cnt|preUserCnt|
+----+---------+--------+----------+
| 站点1|201902025| 50| null|
| 站点1|201902026| 90| 50|
| 站点1|201902026| 100| 90|
| 站点2|201902027| 70| null|
| 站点2|201902028| 60| 70|
| 站点2|201902029| 40| 60|
+----+---------+--------+----------+

如果计算环比的时候,是不是特别有用啊?!

在介绍几个常用的行数:

first/last(): 提取这个分组特定排序的第一个最后一个,在获取用户退出的时候,你可能会用到

lag/lead(field, n): lead 就是 lag 相反的操作,这个用于做数据回测特别用,结果回推条件

排名

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
val spark = SparkSession
.builder()
.master("local")
.appName("DateFrameFromJsonScala")
.config("spark.some.config.option", "some-value")
.getOrCreate()

import spark.implicits._

val df = List(
("站点1", "201902025", 50),
("站点1", "201902026", 90),
("站点1", "201902026", 100),
("站点2", "201902027", 70),
("站点2", "201902028", 60),
("站点2", "201902029", 40)).toDF("site", "date", "user_cnt")

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val wSpce = Window.partitionBy("site").orderBy("date")

df.withColumn("rank", rank().over(wSpce)).show()

结果

1
2
3
4
5
6
7
8
9
10
+----+---------+--------+----+
|site| date|user_cnt|rank|
+----+---------+--------+----+
| 站点1|201902025| 50| 1|
| 站点1|201902026| 90| 2|
| 站点1|201902026| 100| 2|
| 站点2|201902027| 70| 1|
| 站点2|201902028| 60| 2|
| 站点2|201902029| 40| 3|
+----+---------+--------+----+

这个数据在提取每个分组的前n项时特别有用,省了不少麻烦。

自定义窗口函数

https://mp.weixin.qq.com/s/SMNX5lVPb0DWRf27QCcjIQ
https://github.com/zheniantoushipashi/spark-udwf-session

背景

在使用 spark sql 的时候,有时候默认提供的sql 函数可能满足不了需求,这时候可以自定义一些函数,可以自定义 UDF 或者UDAF。

UDF 在sql中只是简单的处理转换一些字段,类似默认的trim 函数把一个字符串类型的列的头尾空格去掉, UDAF函数不同于UDF,是在sql聚合语句中使用的函数,必须配合 GROUP BY 一同使用,类似默认提供的count,sum函数,但是还有一种自定义函数叫做 UDWF, 这种一般人就不知道了,这种叫做窗口自定义函数,不了解窗口函数的,可以参考 1 开胃菜,spark sql 窗口函数的基本概念和使用姿势 ,或者官方的介绍 https://databricks.com/blog/2015/07/15/introducing-window-functions-in-spark-sql.html

窗口函数是 SQL 中一类特别的函数。和聚合函数相似,窗口函数的输入也是多行记录。不同的是,聚合函数的作用于由 GROUP BY 子句聚合的组,而窗口函数则作用于一个窗口

这里怎么理解一个窗口呢,spark君在这里得好好的解释解释,一个窗口是怎么定义的,

窗口语句中,partition by用来指定分区的列,在同一个分区的行属于同一个窗口

order by用来指定窗口内的多行,如何排序

windowing_clause 用来指定开窗方式,在spark sql 中开窗方式有那么几种

  • 一个分区中的所有行作为一个窗口:UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING(上下都没有边界),这种情况下,spark sql 会把所有行作为一个输入,进行一次求值

  • Growing frame:UNBOUNDED PRECEDING AND ….(上无边界), 这种就是不断的把当前行加入的窗口中,而不删除, 例子:.rowsBetween(Long.MinValue, 0) :窗口的大小是按照排序从最小值到当前行,在数据迭代过程中,不断的把当前行加入的窗口中。

  • Shrinking frame:… AND UNBOUNDED FOLLOWING(下无边界)和Growing frame 相反,窗口不断的把迭代到的当前行从窗口中删除掉。

  • Moving frame:滑动的窗口,举例:.rowsBetween(-1, 1) 就是指窗口定义从 -1(当前行前一行)到 1(当前行后一行) ,每一个滑动的窗口总用有3行

  • Offset frame: 窗口中只有一条数据,就是偏移当前行一定距离的那一行,举例:lag(field, n): 就是取从当前行往前的n个行的字段值

这里就针对窗口函数就介绍这么多,如果不懂请参考相关文档,加强理解,我们在平时使用 spark sql 的过程中,会发现有很多教你自定义 UDF 和 UDAF 的教程,却没有针对UDWF的教程,这是为啥呢,这是因为 UDF 和UDAF 都作为上层API暴露给用户了,使用scala很简单就可以写一个函数出来,但是UDWF没有对上层用户暴露,只能使用 Catalyst expressions. 也就是Catalyst框架底层的表达式语句才可以定义,如果没有对源码有很深入的研究,根本就搞不出来。spark 君在工作中写了一些UDWF的函数,但是都比较复杂,不太好单独抽出来作为一个简明的例子给大家讲解,挑一个网上的例子来进行说明,这个例子 spark君亲测可用。

窗口函数的使用场景

我们来举个实际例子来说明 窗口函数的使用场景,在网站的统计指标中,有一个概念叫做用户会话,什么叫做用户会话呢,我来说明一下,我们在网站服务端使用用户session来管理用户状态,过程如下

1) 服务端session是用户第一次访问应用时,服务器就会创建的对象,代表用户的一次会话过程,可以用来存放数据。服务器为每一个session都分配一个唯一的sessionid,以保证每个用户都有一个不同的session对象。

2)服务器在创建完session后,会把sessionid通过cookie返回给用户所在的浏览器,这样当用户第二次及以后向服务器发送请求的时候,就会通过cookie把sessionid传回给服务器,以便服务器能够根据sessionid找到与该用户对应的session对象。

3)session通常有失效时间的设定,比如1个小时。当失效时间到,服务器会销毁之前的session,并创建新的session返回给用户。但是只要用户在失效时间内,有发送新的请求给服务器,通常服务器都会把他对应的session的失效时间根据当前的请求时间再延长1个小时。

也就是说如果用户在1个超过一个小时不产生用户事件,当前会话就结束了,如果后续再产生用户事件,就当做新的用户会话,我们现在就使用spark sql 来统计用户的会话数,首先我们先加一个列,作为当前列的session,然后再 distinct 这个列就得到用户的会话数了,所以关键是怎么加上这个newSession 这个列,这种场景就很适合使用窗口函数来做统计,因为判断当前是否是一个新会话的依据,需要依赖当前行的前一行的时间戳和当前行的时间戳的间隔来判断,下面的表格可以帮助你理解这个概念,例子中有3列数据,用户,event字段代表用户访问了一个页面产生了一个用户事件,time字段代表访问页面的时间戳:

user event time session
user1 page1 10:12 session1(new session)
user1 page2 10:20 session1(same session,8 minutes from last event)
user1 page1 11:13 session1(same session,53 minutes from last event)
user1 page3 14:12 session1(new session,3 minutes from last event)

上面只有一个用户,如果多个用户,可以使用 partition by 来进行分区。

深入研究

构造数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
case class UserActivityData(user:String, ts:Long, session:String)   

val st = System.currentTimeMillis()
val one_minute = 60 * 1000
val d = Array[UserActivityData](
UserActivityData("user1", st, "f237e656-1e53-4a24-9ad5-2b4576a4125d"),
UserActivityData("user2", st + 5*one_minute, null),
UserActivityData("user1", st + 10*one_minute, null),
UserActivityData("user1", st + 15*one_minute, null),
UserActivityData("user2", st + 15*one_minute, null),
UserActivityData("user1", st + 140*one_minute, null),
UserActivityData("user1", st + 160*one_minute, null))

"a CustomWindowFunction" should "correctly create a session " in {
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(sc.parallelize(d))

val specs = Window.partitionBy(f.col("user")).orderBy(f.col("ts").asc)
val res = df.withColumn( "newsession", MyUDWF.calculateSession(f.col("ts"), f.col("session")) over specs)

怎么使用 spark sql 来统计会话数目呢,因为不同用户产生的是不同的会话,首先使用user字段进行分区,然后按照时间戳进行排序,然后我们需要一个自定义函数来加一个列,这个列的值的逻辑如下:

1
2
3
4
IF (no previous event) create new session
ELSE (if cerrent event was past session window)
THEN create new session
ELSE use current session

运行结果如下:

我们使用 UUID 来作为会话id, 当后一行的时间戳和前一行的时间戳间隔大于1小时的时候,就创建一个新的会话id作为列值,否则使用老的会话id作为列值。

这种就涉及到状态,我们在内部需要维护的状态数据

  • 当前的session ID

  • 当前session的最后活动事件的时间戳

下面我们就看看这个怎样自定义 caculateSession 这个窗口函数,先自定义一个静态对象:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import java.util.UUID

import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, AttributeReference, Expression, If, IsNotNull, LessThanOrEqual, Literal, ScalaUDF, Subtract}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String



object MyUDWF {
val defaultMaxSessionLengthms = 3600 * 1000
case class SessionUDWF(timestamp:Expression, session:Expression,
sessionWindow:Expression = Literal(defaultMaxSessionLengthms)) extends AggregateWindowFunction {
self: Product =>

override def children: Seq[Expression] = Seq(timestamp, session)
override def dataType: DataType = StringType

protected val zero = Literal( 0L )
protected val nullString = Literal(null:String)

protected val curentSession = AttributeReference("currentSession", StringType, nullable = true)()
protected val previousTs = AttributeReference("lastTs", LongType, nullable = false)()

override val aggBufferAttributes: Seq[AttributeReference] = curentSession :: previousTs :: Nil

protected val assignSession = If(LessThanOrEqual(Subtract(timestamp, aggBufferAttributes(1)), sessionWindow),
aggBufferAttributes(0), // if
ScalaUDF( createNewSession, StringType, children = Nil))

override val initialValues: Seq[Expression] = nullString :: zero :: Nil
override val updateExpressions: Seq[Expression] =
If(IsNotNull(session), session, assignSession) ::
timestamp ::
Nil

override val evaluateExpression: Expression = aggBufferAttributes(0)
override def prettyName: String = "makeSession"
}

protected val createNewSession = () => org.apache.spark.unsafe.types.UTF8String.fromString(UUID.randomUUID().toString)

def calculateSession(ts:Column,sess:Column): Column = withExpr { SessionUDWF(ts.expr,sess.expr, Literal(defaultMaxSessionLengthms)) }
def calculateSession(ts:Column,sess:Column, sessionWindow:Column): Column = withExpr { SessionUDWF(ts.expr,sess.expr, sessionWindow.expr) }

private def withExpr(expr: Expression): Column = new Column(expr)
}

代码说明:

  • 状态保存在 Seq[AttributeReference]中

  • 重写 initialValues方法进行初始化

  • updateExpressions 函数针对每一行数据都会调用

spark sql 在迭代处理每一行数据的时候,都会调用 updateExpressions 函数来处理,根据当后一行的时间戳和前一行的时间戳间隔大于1小时来进行不同的逻辑处理,如果不大于,就使用 aggBufferAttributes(0) 中保存的老的sessionid,如果大于,就把 createNewSession 包装为一个scalaUDF作为一个子表达式来创建一个新的sessionID,并且每次都把当前行的时间戳作为用户活动的最后时间戳。

最后包装为静态对象的方法,就可以在spark sql中使用这个自定义窗口函数了,下面是两个重载的方法,一个最大间隔时间使用默认值,一个可以运行用户自定义,perfect

现在,我们就可以拿来用在我们的main函数中了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
val st = System.currentTimeMillis()
val one_minute = 60 * 1000

val d = Array[UserActivityData](
UserActivityData("user1", st, "f237e656-1e53-4a24-9ad5-2b4576a4125d"),
UserActivityData("user2", st + 5*one_minute, null),
UserActivityData("user1", st + 10*one_minute, null),
UserActivityData("user1", st + 15*one_minute, null),
UserActivityData("user2", st + 15*one_minute, null),
UserActivityData("user1", st + 140*one_minute, null),
UserActivityData("user1", st + 160*one_minute, null))

val df = spark.createDataFrame(sc.parallelize(d))
val specs = Window.partitionBy(f.col("user")).orderBy(f.col("ts").asc)
val res = df.withColumn( "newsession", MyUDWF.calculateSession(f.col("ts"), f.col("session")) over specs)
df.show(20)
res.show(20, false)

如果我们学会了自定义 spark 窗口函数,原理是就可以处理一切这种对前后数据依赖的统计需求,不过这种自定义毕竟需要对 spark 源码有很深入的研究才可以,这就需要功力了,希望 spark君的读者可以跟着spark君日益精进。文章中的代码可能不太清楚,完整demo参考 https://github.com/zheniantoushipashi/spark-udwf-session。

本文标题:sparksql详解

文章作者:tang

发布时间:2019年01月04日 - 19:01

最后更新:2019年02月26日 - 15:02

原始链接:https://tgluon.github.io/2019/01/04/sparksql详解/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。

-------------本文结束感谢您的阅读-------------