```python
import torch


@torch.no_grad()
def speculative_sampling(
    prefix: torch.Tensor,
    approx_model: torch.nn.Module,
    target_model: torch.nn.Module,
    max_len: int,
    gamma: int = 4,
    temperature: float = 1,
    top_k: int = 0,
    top_p: float = 0,
) -> torch.Tensor:
    seq_len = prefix.shape[1]
    T = seq_len + max_len

    # Wrap both models in KV caches so drafting and verification reuse
    # previously computed attention states across iterations.
    approx_model_cache = KVCacheModel(approx_model, temperature, top_k, top_p)
    target_model_cache = KVCacheModel(target_model, temperature, top_k, top_p)

    resample_count = 0
    target_sample_count = 0
    accepted_count = 0

    while prefix.shape[1] < T:
        prefix_len = prefix.shape[1]

        # Draft phase: the small model proposes gamma tokens autoregressively.
        x = approx_model_cache.generate(prefix, gamma)
        # Verification phase: a single forward pass of the target model
        # scores all gamma draft tokens (plus one extra position) at once.
        _ = target_model_cache.generate(x, 1)

        # n is the index of the last accepted token; start optimistic.
        n = prefix_len + gamma - 1
        for i in range(gamma):
            r = torch.rand(1, device=target_model_cache.device)
            j = x[:, prefix_len + i]

            # Accept draft token j with probability min(1, p(j) / q(j)),
            # where p is the target model and q the approximation model.
            # Note: indexing with j as written assumes batch size 1.
            if r > (
                target_model_cache._prob_history[:, prefix_len + i - 1, j]
            ) / (
                approx_model_cache._prob_history[:, prefix_len + i - 1, j]
            ):
                # Rejected: keep only the tokens before position i.
                n = prefix_len + i - 1
                break
            accepted_count += 1

        prefix = x[:, : n + 1]
        approx_model_cache.rollback(n + 1)

        if n < prefix_len + gamma - 1:
            # A draft token was rejected: resample it from the residual
            # distribution max(0, p - q), renormalized by max_fn.
            t = sample(
                max_fn(
                    target_model_cache._prob_history[:, n, :]
                    - approx_model_cache._prob_history[:, n, :]
                )
            )
            resample_count += 1
            target_model_cache.rollback(n + 1)
        else:
            # All gamma drafts accepted: take one bonus token from the
            # target model's next-token distribution at no extra cost.
            assert n == target_model_cache._prob_history.shape[1] - 1
            t = sample(target_model_cache._prob_history[:, -1, :])
            target_sample_count += 1

        prefix = torch.cat((prefix, t), dim=1)

    return prefix
```
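The listing relies on three helpers it does not define: `KVCacheModel` (a wrapper that maintains a model's KV cache and per-position probability history, exposing `generate` and `rollback`), `max_fn`, and `sample`. A minimal sketch of the latter two, assuming `max_fn` implements the renormalized residual max(0, p - q) from the speculative sampling acceptance rule and `sample` draws from a categorical distribution:

```python
import torch

# Sketches of the helpers assumed above; the bodies are illustrative,
# not necessarily the original implementations.

def max_fn(x: torch.Tensor) -> torch.Tensor:
    """Elementwise max(0, x), renormalized to a valid probability vector."""
    x_pos = torch.clamp(x, min=0)
    return x_pos / x_pos.sum(dim=-1, keepdim=True)


def sample(probs: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
    """Draw token ids from a (batch, vocab) probability matrix."""
    return torch.multinomial(probs, num_samples=num_samples)
```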
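For orientation, a hypothetical invocation might look like the following; the checkpoints, tokenizer, and prompt are placeholders, and `KVCacheModel` is assumed to accept Hugging Face causal LMs:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical setup: a small draft model paired with a larger target model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
approx = AutoModelForCausalLM.from_pretrained("gpt2")
target = AutoModelForCausalLM.from_pretrained("gpt2-large")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
out = speculative_sampling(input_ids, approx, target, max_len=64, gamma=4)
print(tokenizer.decode(out[0]))
```

Note that the draft and target models must share a vocabulary, since the acceptance test compares their probabilities for the same token id position by position.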