Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【开源任务】CINN编译器后端Pass注释添加 #70113

Open
Hongqing-work opened this issue Dec 10, 2024 · 16 comments
Open

【开源任务】CINN编译器后端Pass注释添加 #70113

Hongqing-work opened this issue Dec 10, 2024 · 16 comments
Assignees

Comments

@Hongqing-work
Copy link
Contributor

Hongqing-work commented Dec 10, 2024

一、任务背景与列表

此任务为【开源任务】CINN编译器后端Pass改造 的前置任务
深度学习编译器是一种专门为深度学习模型优化和部署而设计的工具,其功能是将高层次的深度学习模型转换为低层次的、高效的、底层硬件可执行的代码。飞桨3.0推出了与框架一体化的CINN编译器,能同时支持训练和推理过程,并且具备处理动态可变形状输入的能力。目前,CINN的编译器主要被分为两个阶段:前端与后端。前端主要在PIR层面做一些图层的优化,经过lowering之后上层表示会转化为更贴近硬件实现的后端AST IR表示,后端会在AST IR的基础上进行一系列的分析与变换,最终产生更高效的硬件实现。对IR的分析与变换在编译器中被抽象为了Pass。CINN的后端Pass之前缺乏注释,近期,为了升级Pass以及更多地和开源社区交互,后端的已有的Pass需要添加注释并进行升级,本任务为升级改造的前置任务,即理解Pass所做的变换并添加注释。

⭐️ 提交PR 模版 ⭐️:

  • // ------- PR 标题 --------
[CINN][Add Backend Pass Comment No.xxx] Add comment for IfFusion
  • // ------- PR 内容 --------
PR types
CINN

PR changes
Others

Description
为IfFusion Pass添加了注释

本期需要添加注释的pass如下,整体进展:

序号 原转换实现文件 队伍名称/状态/PR 难度
1 optim/eliminate_common_factor_of_local_index.cc @ZHOU05030
2 optim/eliminate_common_global_memory_read.cc @fxy1699 #70304
3 optim/extern_call_process.cc @Albresky #70233
4 optim/trans_buffer_with_dynamic_shape.cc @fxy1699 #70452
@jiachengdai #70449
5 optim/ir_simplify.cc @nizne9
@Albresky #70453
6 optim/longlong2int.cc @yangrongxinuser
@LittleHeroZZZX #70457
@fxy1699 #70448
7 optim/merge_block_utils.cc @PolaKuma #70213
8 optim/rearrange_load_instruction.cc @fangfangssj
@fxy1699
9 optim/remove_schedule_block.cc @nizne9
@LittleHeroZZZX #70225
10 optim/replace_cross_thread_reduction.cc @KDZZZZZZ #70227
11 optim/schedule_block_dce.cc @SCUcookie
@fxy1699 #70279 #70304
12 optim/transform_gpu_forloop.cc @hanyang2508 #70289
@fxy1699 #70296
13 optim/update_buffer_axis_pass.cc @PolaKuma #70271

看板信息

任务方向 任务数量 提交作品 / 任务认领 提交率 完成 完成率
CINN编译器后端Pass改造 13 11 / 13 84.62% 8 61.54%

统计信息

排名不分先后 @fxy1699 (2) @Albresky (1) @LittleHeroZZZX (2) @PolaKuma (2) @hanyang2508 (1)

二、任务详情

2.1 CINN编译器介绍

CINN的架构如下图所示,分为前端后端和执行器,其中前端的主要功能是基于 PIR 进行图层级别的优化,并对子图进行划分为后端高性能 Kernel 代码生成提供支持;后端主要负责将前端处理后的 IR 转换为目标硬件可执行的代码或硬件描述,主要功能包括基于硬件特性的 IR 优化、高效内存管理和代码生成等;最后再由执行器的运行调度接口对编译器生成的 Kernel 进行封装。
image
这里出现了两类IR,PIR和后端的AST IR,它们都起到了计算和数据进行表示表达的作用,上述的编译期的整个工作过程其实也可以说是对IR的分析变换,我们抽象为前后端Pass以及后端特有的编排调优Schedule
比如很简单的一个子图:

# shape of x, y is [64, 128]
def forward(self, x, y):
    tmp = x - y
    out = tmp * x
    return out

转换成PIR就变成了如下Tensor级别的高层次表示,不体现底层的计算逻辑:

{
    (%0) = "pd_op.data" [id:18] () {dtype:(pd_op.DataType)float32,name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[64,128],stop_gradient:[false]} : () -> builtin.tensor<64x128xf32> { () }	(op_18)
    (%1) = "pd_op.data" [id:19] () {dtype:(pd_op.DataType)float32,name:"y",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[64,128],stop_gradient:[false]} : () -> builtin.tensor<64x128xf32> { () }	(op_19)
    (%2) = "pd_op.subtract" [id:20] (%0, %1) {stop_gradient:[false]} : (builtin.tensor<64x128xf32>, builtin.tensor<64x128xf32>) -> builtin.tensor<64x128xf32> { () }	(op_20)
    (%3) = "pd_op.multiply" [id:21] (%2, %0) {stop_gradient:[false]} : (builtin.tensor<64x128xf32>, builtin.tensor<64x128xf32>) -> builtin.tensor<64x128xf32> { () }	(op_21)
    () = "builtin.shadow_output" [id:22] (%3) {output_name:"output_0"} : (builtin.tensor<64x128xf32>) ->  {  }	(op_22)
}

经过CINN的前端变换会得到一组组的可以融合起来的FusionOp,这里例子里只有一组subtract+multiply的FusionOp:

{
    (%0) = "pd_op.data" [id:18] () {dtype:(pd_op.DataType)float32,name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[64,128],stop_gradient:[false],sym_shape_str:"shape[64, 128], data[NULL]"} : () -> builtin.tensor<64x128xf32> { (shape[64, 128], data[NULL]) }	(op_18)
    (%1) = "pd_op.data" [id:19] () {dtype:(pd_op.DataType)float32,name:"y",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[64,128],stop_gradient:[false],sym_shape_str:"shape[64, 128], data[NULL]"} : () -> builtin.tensor<64x128xf32> { (shape[64, 128], data[NULL]) }	(op_19)
    (%2) = "cinn_op.fusion" [id:29] () -> builtin.tensor<64x128xf32> {
        (%3) = "pd_op.subtract" [id:26] (%0, %1) {stop_gradient:[false],sym_shape_str:"shape[64, 128], data[NULL]"} : (builtin.tensor<64x128xf32>, builtin.tensor<64x128xf32>) -> builtin.tensor<64x128xf32> { (shape[64, 128], data[NULL]) }	(op_26)
        (%4) = "pd_op.multiply" [id:27] (%3, %0) {stop_gradient:[false],sym_shape_str:"shape[64, 128], data[NULL]"} : (builtin.tensor<64x128xf32>, builtin.tensor<64x128xf32>) -> builtin.tensor<64x128xf32> { (shape[64, 128], data[NULL]) }	(op_27)
        (%5) = "cinn_op.yield_store" [id:28] (%4) {} : (builtin.tensor<64x128xf32>) -> builtin.tensor<64x128xf32> { (shape[64, 128], data[NULL]) }	(op_28)
        () = "cf.yield" [id:30] (%5) {} : (builtin.tensor<64x128xf32>) ->  {  }	(op_30)
    } { (shape[64, 128], data[NULL]) }	(op_29)
    () = "builtin.shadow_output" [id:22] (%2) {output_name:"output_0",sym_shape_str:"shape[64, 128], data[NULL]"} : (builtin.tensor<64x128xf32>) ->  {  }	(op_22)
}

后端会对这个FusionOp进行代码生成,然后编译成Jit Kernel以供执行器调用,这里的第一步就是需要将前端IR lowering转换成后端AST IR。AST IR更直观地表达出一个子图到底是怎么算的:

{
  ScheduleBlock(root_1)
  {
    serial for (i, 0ll, 64ll)
    {
      serial for (j, 0ll, 128ll)
      {
        ScheduleBlock(var_3)
        {
          i0_1, i1_1 = axis.bind(i, j) // 用于调度的信息
          var_3[i0_1, i1_1] = ((var[i0_1, i1_1] - var_0[i0_1, i1_1]) * var[i0_1, i1_1]) // 实际需要调度的语句
        }
      }
    }
  }
}

可以看出,两个shape为[64, 128]的tensor的相减后乘是通过两层串行的for循环实现的,循环体进行tensor特定元素的减法和乘法。除了for和加减乘除这样的常见语法,这里还出现了ScheduleBlock这个概念。这其实是为了后续的Schedule编排优化而对语句做的封装,经过Schedule之后这段代码能更贴近硬件实现比如使用32个block,每个block256个线程来完成上述的计算:

{
  ScheduleBlock(root_1)
  {
    thread_bind[blockIdx.x] for (i_j_fused, 0, 32)
    {
      thread_bind[threadIdx.x] for (i_j_fused_0, 0, 256)
      {
        ScheduleBlock(var_3)
        {
          i0_1, i1_1 = axis.bind((((i_j_fused * 256) + i_j_fused_0) / 128), (i_j_fused_0 % 128ll)) // 用于调度的信息
          read_buffers(_var[i0_1(0:64ll), i1_1(0:128ll)], _var_0[i0_1(0:64ll), i1_1(0:128ll)], _var[i0_1(0:64ll), i1_1(0:128ll)]) // 用于调度的信息
          write_buffers(_var_3[i0_1(0:64ll), i1_1(0:128ll)]) // 用于调度的信息
          var_3[i0_1, i1_1] = ((var[i0_1, i1_1] - var_0[i0_1, i1_1]) * var[i0_1, i1_1]) // 实际需要调度的语句
        }
      }
    }
  }
}

2.2 Pass注释规范

2.2.1 注释包含内容

  1. 必需:一句话概括pass内容
  2. 必需:一段话详细说明pass的应用场景,即什么情况下可以应用这个pass
  3. 必需:一段话详细说明如果可以应用这个pass,会对ir进行什么修改
  4. 可选:如果这是一个性能pass,应当说明它解决的是哪类性能问题
  5. 可选:如果这个pass有风险点,或者有TODO没完成,应该列举出来
  6. 可选:复杂pass需要举出2个例子
  7. 可选:如果这个pass有风险点,应该同时举一些反例,说明在这些情况下不能应用这个pass

2.2.2 注释格式

/**
 * [required]
 * Brief description of the pass in one sentence
 *
 * [required: Detailed application scenario]
 * This pass is applicable in scenarios where {describe the specific conditions 
 * or code patterns where this pass can be applied}. {Explain why these scenarios 
 * are common or important}.
 *
 * [required: IR modifications]
 * When applied, this pass will {describe in detail the changes made to the IR, 
 * including what is added, removed, or modified}.
 *
 * [optional: Performance impact]
 * Performance impact: This pass addresses [describe the specific performance 
 * issues this pass aims to solve, such as reducing memory usage, improving 
 * cache efficiency, etc.].
 *
 * [optional: Risks, limitations, and TODOs]
 * Risks and limitations:
 * - {List potential risks or limitations}
 * - {Add more points as needed}
 * TODO: 
 * - {List any incomplete aspects or future improvements}
 * - {Add more TODOs as needed}
 *
 * [optional: Examples]
 * Examples:
 * 1. {Example name}:
 *    Input IR:
 *      {Show input IR}
 *    Output IR:
 *      {Show output IR after pass application}
 *
 * 2. {Another example name}:
 *    Input IR:
 *      {Show input IR}
 *    Output IR:
 *      {Show output IR after pass application}
 *
 * [optional: Counter-examples (if applicable)]
 * Counter-examples (cases where the pass should not be applied):
 * 1. {Counter-example name}:
 *    {Show IR or describe situation}
 *    {Explain why the pass shouldn't be applied here}
 *
 * 2. {Another counter-example}:
 *    {Show IR or describe situation}
 *    {Explain why the pass shouldn't be applied here}
 */

三、可参考PR

#69611 paddle/cinn/optim/if_fusion_pass.h
#70092 paddle/cinn/ir/group_schedule/tactic/tile_broadcast_tactic.cc
#70258 paddle/cinn/optim/rearrange_load_instruction.h

@SCUcookie
Copy link
Contributor

【报名】:11

@yangrongxinuser
Copy link
Contributor

【报名】:6

@Albresky
Copy link
Contributor

【报名】:3

@nizne9
Copy link
Contributor

nizne9 commented Dec 12, 2024

【报名】:5、9

@fangfangssj
Copy link
Contributor

fangfangssj commented Dec 12, 2024

【报名】:8

@ZHOU05030
Copy link
Contributor

【报名】:1

@PolaKuma
Copy link
Contributor

【报名】:7

@LittleHeroZZZX
Copy link
Contributor

【报名】:9

@KDZZZZZZ
Copy link

【报名】:10

@hanyang2508
Copy link
Contributor

【报名】:12

@PolaKuma
Copy link
Contributor

【报名】:13

@fxy1699
Copy link
Contributor

fxy1699 commented Dec 17, 2024

【报名】:2、4、11、12

@LittleHeroZZZX
Copy link
Contributor

【报名】:6

@Albresky
Copy link
Contributor

【报名】:5

@fxy1699
Copy link
Contributor

fxy1699 commented Dec 25, 2024

【报名】:6、8

@jiachengdai
Copy link
Contributor

【报名】:4

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Status: In Progress
Development

No branches or pull requests