1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62
# IRModule with three TensorIR primitive functions (relu0, linear0, linear1)
# and a Relax entry point (main) that chains them: linear0 -> relu0 -> linear1.
@tvm.script.ir_module
class MyModule:
    # Elementwise ReLU on a (1, 128) float32 buffer: Y = max(X, 0).
    @T.prim_func
    def relu0(
        X: T.Buffer[(1, 128), "float32"],
        Y: T.Buffer[(1, 128), "float32"],
    ):
        # Exported symbol name; "tir.noalias" declares that the buffer
        # arguments do not alias one another.
        T.func_attr({"global_symbol": "relu0", "tir.noalias": True})
        for i, j in T.grid(1, 128):
            with T.block("Y"):
                # Both block axes are spatial ("S").
                vi, vj = T.axis.remap("SS", [i, j])
                Y[vi, vj] = T.max(X[vi, vj], T.float32(0))

    # Dense layer 784 -> 128: Z[i, j] = sum_k X[i, k] * W[j, k] + B[j],
    # i.e. Z = X @ W^T + B, with W laid out as (out_features, in_features).
    @T.prim_func
    def linear0(
        X: T.Buffer[(1, 784), "float32"],
        W: T.Buffer[(128, 784), "float32"],
        B: T.Buffer[(128,), "float32"],
        Z: T.Buffer[(1, 128), "float32"],
    ):
        T.func_attr({"global_symbol": "linear0", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the bias add.
        Y = T.alloc_buffer((1, 128), "float32")
        for i, j, k in T.grid(1, 128, 784):
            with T.block("Y"):
                # vi, vj are spatial axes; vk is the reduction ("R") axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Zero the accumulator before the first reduction step.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
        # Add the bias along the output-feature axis (runs after the full
        # reduction above has completed).
        for i, j in T.grid(1, 128):
            with T.block("Z"):
                vi, vj = T.axis.remap("SS", [i, j])
                Z[vi, vj] = Y[vi, vj] + B[vj]

    # Dense layer 128 -> 10: Z = X @ W^T + B (same structure as linear0).
    @T.prim_func
    def linear1(
        X: T.Buffer[(1, 128), "float32"],
        W: T.Buffer[(10, 128), "float32"],
        B: T.Buffer[(10,), "float32"],
        Z: T.Buffer[(1, 10), "float32"],
    ):
        T.func_attr({"global_symbol": "linear1", "tir.noalias": True})
        # Intermediate buffer holding the matmul result before the bias add.
        Y = T.alloc_buffer((1, 10), "float32")
        for i, j, k in T.grid(1, 10, 128):
            with T.block("Y"):
                # vi, vj are spatial axes; vk is the reduction ("R") axis.
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # Zero the accumulator before the first reduction step.
                with T.init():
                    Y[vi, vj] = T.float32(0)
                Y[vi, vj] = Y[vi, vj] + X[vi, vk] * W[vj, vk]
        # Add the bias along the output-feature axis.
        for i, j in T.grid(1, 10):
            with T.block("Z"):
                vi, vj = T.axis.remap("SS", [i, j])
                Z[vi, vj] = Y[vi, vj] + B[vj]

    # Relax entry point: linear0 (784 -> 128), relu0, then linear1 (128 -> 10).
    @R.function
    def main(
        x: Tensor((1, 784), "float32"),
        w0: Tensor((128, 784), "float32"),
        b0: Tensor((128,), "float32"),
        w1: Tensor((10, 128), "float32"),
        b1: Tensor((10,), "float32"),
    ):
        # Dataflow block; R.output exposes `out` outside the block.
        with R.dataflow():
            # Each call_tir invokes a prim_func on the given inputs; the output
            # tensor has the stated shape/dtype and is passed as the
            # prim_func's last (destination) argument.
            lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32")
            lv1 = R.call_tir(relu0, (lv0,), (1, 128), dtype="float32")
            out = R.call_tir(linear1, (lv1, w1, b1), (1, 10), dtype="float32")
            R.output(out)
        return out
|