diff --git a/infini_train/include/dispatcher.h b/infini_train/include/dispatcher.h index 29d11b73..638df76a 100644 --- a/infini_train/include/dispatcher.h +++ b/infini_train/include/dispatcher.h @@ -74,6 +74,9 @@ class Dispatcher { template RetT Call(KeyT key, ArgsT... args) const { auto kernel = this->GetKernel(key); tls_autocast_context.Autocast(key, args...); +#ifdef PROFILE_MODE + SetProfileContext(key.second, key.first); +#endif return kernel.Call(std::forward(args)...); }