Zhangzhe's Blog

The projection of my life.

0%

机器学习编译(5)——与机器学习框架的整合

URL

使用 Builder 创建 IRModule

从张量表达式创建 TensorIR(主张量函数)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from tvm import te
# Declare the two TensorIR inputs as 128x128 float32 placeholders
A = te.placeholder((128, 128), name="A", dtype="float32")
B = te.placeholder((128, 128), name="B", dtype="float32")
type(A)
# Notebook output: tvm.te.tensor.Tensor
A.shape
# Notebook output: [128, 128]
# TensorIR is generated automatically from this tensor expression
def te_matmul(A: te.Tensor, B: te.Tensor) -> te.Tensor:
    """Describe the matrix product of ``A`` and ``B`` as a tensor expression.

    Args:
        A: Left operand, read as ``A[i, k]``.
        B: Right operand, read as ``B[k, j]``.

    Returns:
        The ``te.compute`` result named ``"matmul"`` whose shape is
        ``(A.shape[0], B.shape[1])``.

    Raises:
        AssertionError: If ``A.shape[1]`` does not equal ``B.shape[0]``.
    """
    assert A.shape[1] == B.shape[0]
    n = A.shape[0]
    m = B.shape[1]
    # Reduction axis spanning the shared inner dimension.
    k = te.reduce_axis((0, A.shape[1]), name="k")
    # Call form: te.compute(output_shape, lambda, name)
    return te.compute(
        (n, m), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul"
    )
C = te_matmul(A, B)
# Print the generated TensorIR; the function parameters are [A, B, C]
te.create_prim_func([A, B, C]).show()
  • 输出(自动生成的主张量函数)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Listing printed by te.create_prim_func([A, B, C]).show() (indentation restored)
# from tvm.script import tir as T
@T.prim_func
def func(
    A: T.Buffer[(128, 128), "float32"],
    B: T.Buffer[(128, 128), "float32"],
    matmul: T.Buffer[(128, 128), "float32"],
) -> None:
    # function attr dict
    T.func_attr({"global_symbol": "main", "tir.noalias": True})
    # body
    # with T.block("root")
    for i0, i1, i2 in T.grid(128, 128, 128):
        with T.block("matmul"):
            i, j, k = T.axis.remap("SSR", [i0, i1, i2])
            T.reads(A[i, k], B[k, j])
            T.writes(matmul[i, j])
            with T.init():
                matmul[i, j] = T.float32(0)
            matmul[i, j] = matmul[i, j] + A[i, k] * B[k, j]

使用 BlockBuilder 构造 IRModule

  • 自动生成的主张量函数还需要 计算图抽象 来将计算图拼起来
1
2
3
4
5
6
7
8
9
10
11
12
# Relax variables for the two 128x128 float32 inputs of "main"
A = relax.Var("A", (128, 128), relax.DynTensorType(2, "float32"))
B = relax.Var("B", (128, 128), relax.DynTensorType(2, "float32"))
# Use BlockBuilder to stitch several tensor functions into one IRModule
bb = relax.BlockBuilder()
with bb.function("main"):
    with bb.dataflow():
        C = bb.emit_te(te_matmul, A, B)
        # NOTE(review): te_relu is not defined in this listing; presumably a
        # tensor-expression ReLU defined alongside te_matmul - confirm.
        D = bb.emit_te(te_relu, C)
        R = bb.emit_output(D)
    # Close the function: R is its return value, A and B its parameters.
    bb.emit_func_output(R, params=[A, B])
MyModule = bb.get()
MyModule.show()
  • 输出(IRModule)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Listing printed by MyModule.show() (indentation restored; prim_func bodies elided)
@tvm.script.ir_module
class Module:
    @T.prim_func
    def te_matmul(rxplaceholder: T.Buffer[(128, 128), "float32"], rxplaceholder_1: T.Buffer[(128, 128), "float32"], matmul: T.Buffer[(128, 128), "float32"]) -> None:
        ...
    @T.prim_func
    def te_relu(rxplaceholder: T.Buffer[(128, 128), "float32"], relu: T.Buffer[(128, 128), "float32"]) -> None:
        ...
    @R.function
    def main(A: Tensor((128, 128), "float32"), B: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv = R.call_tir(te_matmul, (A, B), (128, 128), dtype="float32")
            lv1 = R.call_tir(te_relu, (lv,), (128, 128), dtype="float32")
            gv: Tensor((128, 128), "float32") = lv1
            R.output(gv)
        return gv
  • 使用 BlockBuilder 创建 IRModule 与直接创建 IRModule 的对比
    integration_block_builder.png
  • bb.emit_te 做了以下事情:
    • 为 A 和 B 各创建一个输入 te.placeholder
    • 通过 te_matmul 函数运行它们
    • 调用 te.create_prim_func 来创建一个 TensorIR 函数
    • 通过 call_tir 生成对函数的调用

Pytorch 映射到 IRModule

Pytorch 模型

1
2
3
4
5
6
7
8
9
10
11
class MyModel(nn.Module):
    """Toy model that computes ``relu(x @ weight)`` with a 128x128 weight."""

    def __init__(self):
        """Create the single learnable 128x128 weight, randomly initialised."""
        super().__init__()
        self.weight = nn.Parameter(torch.randn(128, 128))

    def forward(self, x):
        """Multiply ``x`` by the weight, then apply ReLU."""
        x = torch.matmul(x, self.weight)
        x = torch.relu(x)
        return x
model = MyModel()
# Trace the model into a PyTorch FX computational graph
fx_module = fx.symbolic_trace(model)

构造计算图之间的映射变换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# pytorch module parameter to IRModule parameter
def map_param(param: nn.Parameter):
    """Wrap a PyTorch parameter as a Relax constant.

    Args:
        param: Parameter whose data is copied to the CPU and converted to a
            NumPy array.

    Returns:
        The ``relax.const`` built from that array, typed as a float32 tensor
        with the same number of dimensions as ``param``.
    """
    ndim = len(param.data.shape)
    return relax.const(
        param.data.cpu().numpy(), relax.DynTensorType(ndim, "float32")
    )
# pytorch module attribute to IRModule attribute
def fetch_attr(fx_mod, target: str):
    """Resolve a dotted attribute path starting from ``fx_mod``.

    Args:
        fx_mod: Object to start the lookup from.
        target: Attribute path with components separated by ``"."``.

    Returns:
        The object reached after following every component of ``target``.

    Raises:
        RuntimeError: If a component does not exist; the message names the
            path up to and including the missing component.
    """
    target_atoms = target.split('.')
    attr_itr = fx_mod
    for i, atom in enumerate(target_atoms):
        if not hasattr(attr_itr, atom):
            # Include the failing atom itself so the message points at what is missing.
            missing = '.'.join(target_atoms[:i + 1])
            raise RuntimeError(f"Node referenced nonexistant target {missing}")
        attr_itr = getattr(attr_itr, atom)
    return attr_itr
def from_fx(fx_mod, input_shapes, call_function_map, call_module_map):
    """Translate a traced FX module into an IRModule with one "main" function.

    Args:
        fx_mod: Traced module; its ``graph.nodes`` are visited in order.
        input_shapes: Shapes of the inputs, consumed in the order the
            ``placeholder`` nodes appear.
        call_function_map: Maps ``node.target`` of ``call_function`` nodes to
            a callable invoked as ``fn(bb, node_map, node)``.
        call_module_map: Maps the type of the submodule referenced by a
            ``call_module`` node to a callable invoked as
            ``fn(bb, node_map, node, named_module)``.

    Returns:
        The module produced by ``bb.get()``.

    Raises:
        AssertionError: If the graph contains more than one ``output`` node.
    """
    input_index = 0
    # FX node -> Relax expression already created for it.
    node_map = {}
    named_modules = dict(fx_mod.named_modules())
    bb = relax.BlockBuilder()
    fn_inputs = []
    fn_output = None
    with bb.function("main"):
        with bb.dataflow():
            for node in fx_mod.graph.nodes:
                if node.op == "placeholder":
                    # create input placeholder
                    shape = input_shapes[input_index]
                    input_index += 1
                    input_var = relax.Var(
                        node.target, shape, relax.DynTensorType(len(shape), "float32")
                    )
                    fn_inputs.append(input_var)
                    node_map[node] = input_var
                elif node.op == "get_attr":
                    node_map[node] = map_param(fetch_attr(fx_mod, node.target))
                elif node.op == "call_function":
                    node_map[node] = call_function_map[node.target](bb, node_map, node)
                elif node.op == "call_module":
                    named_module = named_modules[node.target]
                    node_map[node] = call_module_map[type(named_module)](bb, node_map, node, named_module)
                elif node.op == "output":
                    output = node_map[node.args[0]]
                    assert fn_output is None
                    fn_output = bb.emit_output(output)
                # Any other node.op is skipped without creating a node_map entry.
        # output and finalize the function
        # NOTE(review): this passes `output` (the value before emit_output)
        # rather than `fn_output`; confirm which one is intended.
        bb.emit_func_output(output, fn_inputs)
    return bb.get()

映射 Pytorch Module 到 TensorIR

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Mapping to TensorIR
def map_matmul(bb, node_map, node: fx.Node):
    """Emit a ``te_matmul`` call for an FX node with two tensor arguments.

    Args:
        bb: Block builder used to emit the call.
        node_map: Maps already-translated FX nodes to their Relax values.
        node: FX node whose first two ``args`` are the operands.

    Returns:
        The value returned by ``bb.emit_te``.
    """
    A = node_map[node.args[0]]
    B = node_map[node.args[1]]
    return bb.emit_te(te_matmul, A, B)
# Mapping to TensorIR
def map_relu(bb, node_map, node: fx.Node):
    """Emit a ``te_relu`` call for an FX node with one tensor argument.

    Args:
        bb: Block builder used to emit the call.
        node_map: Maps already-translated FX nodes to their Relax values.
        node: FX node whose first ``args`` entry is the operand.

    Returns:
        The value returned by ``bb.emit_te``.
    """
    A = node_map[node.args[0]]
    return bb.emit_te(te_relu, A)
# Convert the traced FX graph: one (1, 128) input; torch.matmul and torch.relu
# calls are translated by map_matmul / map_relu; no module-level mappings.
MyModule = from_fx(
fx_module,
input_shapes = [(1, 128)],
call_function_map = {
torch.matmul: map_matmul,
torch.relu: map_relu,
},
call_module_map={},
)
MyModule.show()
  • 映射后的 IRModule
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Listing printed by MyModule.show() (indentation restored; prim_func bodies elided)
@tvm.script.ir_module
class Module:
    @T.prim_func
    def te_matmul(rxplaceholder: T.Buffer[(1, 128), "float32"], rxplaceholder_1: T.Buffer[(128, 128), "float32"], matmul: T.Buffer[(1, 128), "float32"]) -> None:
        ...
    @T.prim_func
    def te_relu(rxplaceholder: T.Buffer[(1, 128), "float32"], relu: T.Buffer[(1, 128), "float32"]) -> None:
        ...
    @R.function
    def main(x: Tensor((1, 128), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv = R.call_tir(te_matmul, (x, meta[relay.Constant][0]), (1, 128), dtype="float32")
            lv1 = R.call_tir(te_relu, (lv,), (1, 128), dtype="float32")
            gv: Tensor((1, 128), "float32") = lv1
            R.output(gv)
        return lv1

或映射 Pytorch Module 到 IRModule 更高层的算子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def map_nn_relu_op(bb, node_map, node, nn_mod):
    """Emit a high-level ``relax.op.relu`` for a ReLU module call.

    Args:
        bb: Block builder used to emit the operator.
        node_map: Maps already-translated FX nodes to their Relax values.
        node: FX node whose first ``args`` entry is the input.
        nn_mod: The module being called; unused here.

    Returns:
        The value returned by ``bb.emit``.
    """
    A = node_map[node.args[0]]
    return bb.emit(relax.op.relu(A))
def map_nn_linear_op(bb, node_map, node, nn_mod):
    """Emit high-level ``dense`` (plus ``add`` when biased) for a linear module.

    Args:
        bb: Block builder used to emit the operators.
        node_map: Maps already-translated FX nodes to their Relax values.
        node: FX node whose first ``args`` entry is the input.
        nn_mod: Module providing ``weight`` and an optional ``bias``.

    Returns:
        The emitted ``add`` result when ``nn_mod.bias`` is set, otherwise the
        emitted ``dense`` result.
    """
    x = node_map[node.args[0]]
    w = map_param(nn_mod.weight)
    b = map_param(nn_mod.bias) if nn_mod.bias is not None else None
    y = bb.emit(relax.op.dense(x, w))
    if b is None:
        # No bias: return the dense result instead of touching an unbound `b`.
        return y
    return bb.emit(relax.op.add(y, b))
# Convert a traced MLP using module-level mappings only: Linear and ReLU
# submodules are translated by map_nn_linear_op / map_nn_relu_op.
# NOTE(review): mlp_model is not defined in this listing - confirm where it comes from.
MLPModuleHighLevel = from_fx(
fx.symbolic_trace(mlp_model),
input_shapes = [(1, 784)],
call_function_map={
},
call_module_map={
torch.nn.Linear: map_nn_linear_op,
torch.nn.ReLU: map_nn_relu_op,
},
)
MLPModuleHighLevel.show()
  • 输出
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Listing printed by MLPModuleHighLevel.show() (indentation restored)
@tvm.script.ir_module
class Module:
    @R.function
    def main(x: Tensor((1, 784), "float32")) -> Tensor(None, "float32", ndim = 2):
        # block 0
        with R.dataflow():
            lv: Tensor((1, 128), "float32") = relax.nn.dense(x, meta[relay.Constant][0])
            lv1: Tensor((1, 128), "float32") = relax.add(lv, meta[relay.Constant][1])
            lv2: Tensor((1, 128), "float32") = relax.nn.relu(lv1)
            lv3: Tensor((1, 10), "float32") = relax.nn.dense(lv2, meta[relay.Constant][2])
            lv4: Tensor((1, 10), "float32") = relax.add(lv3, meta[relay.Constant][3])
            gv: Tensor((1, 10), "float32") = lv4
            R.output(gv)
        return lv4

总结

  • 张量表达式 API 允许我们创建原始的 TensorIR 函数
  • BlockBuilder API 通过 emit_te 和其他函数创建 IRModule
  • 通过将模型转换为 IRModule,实现与现有的机器学习框架的整合