# (line-number gutter from the original code listing — removed as extraction residue)
@tvm.script.ir_module
class MyModuleMixture:
    # IRModule mixing two kinds of functions:
    #   * linear0 — a TensorIR prim_func compiled by TVM itself, and
    #   * main    — a Relax function that calls linear0 via call_tir plus two
    #     externally registered packed funcs ("env.relu", "env.linear"),
    #     registered elsewhere in this file with tvm.register_func.

    @T.prim_func
    def linear0(
        X: T.Buffer[(1, 784), "float32"],    # input activations (batch 1)
        W: T.Buffer[(128, 784), "float32"],  # weight matrix
        B: T.Buffer[(128,), "float32"],      # bias vector
        Z: T.Buffer[(1, 128), "float32"],    # output buffer (destination-passing)
    ):
        T.func_attr({"global_symbol": "linear0", "tir.noalias": True})
        # Body elided in this listing; presumably Z = X @ W.T + B — TODO confirm.
        ...

    @R.function
    def main(
        x: Tensor((1, 784), "float32"),
        w0: Tensor((128, 784), "float32"),
        b0: Tensor((128,), "float32"),
        w1: Tensor((10, 128), "float32"),
        b1: Tensor((10,), "float32"),
    ):
        # Each call_tir invokes a destination-passing-style function with the
        # given args and allocates an output of the stated shape/dtype.
        with R.dataflow():
            lv0 = R.call_tir(linear0, (x, w0, b0), (1, 128), dtype="float32")
            lv1 = R.call_tir("env.relu", (lv0,), (1, 128), dtype="float32")
            out = R.call_tir("env.linear", (lv1, w1, b1), (1, 10), dtype="float32")
            R.output(out)
        return out
@tvm.register_func("env.linear", override=True)
def torch_linear(
    x: tvm.nd.NDArray, w: tvm.nd.NDArray, b: tvm.nd.NDArray, out: tvm.nd.NDArray
):
    # Registered as the packed func "env.linear" (override=True replaces any
    # earlier registration); invoked from MyModuleMixture.main via call_tir.
    # Destination-passing style: the result is written into `out`.
    # Body elided in this listing; the name suggests a torch-backed
    # x @ w.T + b — TODO confirm against the full source.
    ...
@tvm.register_func("env.relu", override=True)
def lnumpy_relu(x: tvm.nd.NDArray, out: tvm.nd.NDArray):
    # Registered as the packed func "env.relu"; invoked from
    # MyModuleMixture.main via call_tir. Destination-passing style: writes
    # into `out`. Body elided in this listing; the name suggests a
    # numpy-backed elementwise ReLU — TODO confirm against the full source.
    ...
# Bind the trained parameters (nd_params — defined elsewhere in the file,
# presumably a dict of weight/bias NDArrays keyed by parameter name — TODO
# confirm) into "main" as constants, so the bound module's main only takes
# the data input.
MyModuleWithParams = relax.transform.BindParams("main", nd_params)(MyModuleMixture)

# Compile the parameter-bound module for CPU (LLVM backend) and wrap the
# executable in a Relax virtual machine running on the CPU device.
ex = relax.vm.build(MyModuleWithParams, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

# Run inference on data_nd (input NDArray defined elsewhere in the file).
nd_res = vm["main"](data_nd)

# Benchmark helper for "main" on CPU: each measurement averages 100 runs.
# (The timer is created here; invoking/printing it is not shown in this chunk.)
ftimer = vm.module.time_evaluator("main", tvm.cpu(), number=100)