diff --git a/.clang-format b/.clang-format index cc15d80c..af4d5905 100644 --- a/.clang-format +++ b/.clang-format @@ -1,13 +1,50 @@ --- -BasedOnStyle: LLVM +BasedOnStyle: Google IndentWidth: 4 AccessModifierOffset: -4 +PointerAlignment: Right +DerivePointerAlignment: false +AlignEscapedNewlines: Right AlignOperands: AlignAfterOperator +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: MultiLine BreakBeforeBinaryOperators: All ColumnLimit: 120 -AllowShortBlocksOnASingleLine: Always +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyReturnTypeOnItsOwnLine: 60 +AllowShortBlocksOnASingleLine: Always +AllowShortIfStatementsOnASingleLine: Never AllowShortLoopsOnASingleLine: true +IndentCaseLabels: false +KeepEmptyLinesAtTheStartOfBlocks: true +PackConstructorInitializers: BinPack +SpacesBeforeTrailingComments: 1 +Standard: Latest InsertBraces: true +SortIncludes: CaseSensitive +IncludeBlocks: Regroup +IncludeCategories: + # C system headers. + - Regex: '^<(assert|complex|ctype|errno|fenv|float|inttypes|iso646|limits|locale|math|setjmp|signal|stdalign|stdarg|stdbool|stddef|stdint|stdio|stdlib|string|tgmath|time|uchar|wchar|wctype)\.h>$' + Priority: 1 + # C++ standard library headers. 
+ - Regex: '^<(algorithm|any|array|atomic|barrier|bit|bitset|cassert|ccomplex|cctype|cerrno|cfenv|cfloat|charconv|chrono|cinttypes|ciso646|climits|clocale|cmath|codecvt|compare|complex|concepts|condition_variable|coroutine|csetjmp|csignal|cstdalign|cstdarg|cstdbool|cstddef|cstdint|cstdio|cstdlib|cstring|ctgmath|ctime|cuchar|cwchar|cwctype|deque|exception|execution|expected|filesystem|format|forward_list|fstream|functional|future|initializer_list|iomanip|ios|iosfwd|iostream|istream|iterator|latch|limits|list|locale|map|memory|memory_resource|mutex|new|numbers|numeric|optional|ostream|queue|random|ranges|ratio|regex|scoped_allocator|semaphore|set|shared_mutex|source_location|span|sstream|stack|stdexcept|stop_token|streambuf|string|string_view|strstream|syncstream|system_error|thread|tuple|type_traits|typeindex|typeinfo|unordered_map|unordered_set|utility|valarray|variant|vector|version)>$' + Priority: 2 + # Other external library headers, for example CUDA/MACA/NCCL/MPI. + - Regex: '^<.*>$' + Priority: 3 + # vendored third-party headers included with quotes. + - Regex: '^"(third_party/|Eigen/|gflags/|glog/)' + Priority: 4 + # Public project interfaces. + - Regex: '^"infini_train/include/' + Priority: 5 + # Internal project implementation headers. + - Regex: '^"infini_train/src/' + Priority: 6 + # Examples and other local quoted headers. 
+ - Regex: '^".*"$' + Priority: 7 BreakBeforeBraces: Custom BraceWrapping: AfterCaseLabel: false @@ -28,4 +65,3 @@ BraceWrapping: SplitEmptyFunction: true SplitEmptyRecord: true SplitEmptyNamespace: true - \ No newline at end of file diff --git a/.github/workflows/format-check.yaml b/.github/workflows/format-check.yaml index 7c289a71..fb1bdd5d 100644 --- a/.github/workflows/format-check.yaml +++ b/.github/workflows/format-check.yaml @@ -16,12 +16,43 @@ jobs: - name: Checkout code uses: actions/checkout@v4 + - name: Install system dependencies + run: | + sudo apt-get update + sudo apt-get install -y clang-format-16 include-what-you-use + - name: Install Python dependencies run: | python3 -m pip install --upgrade pip - pip install black + pip install black colorama - name: Run format check run: | python3 scripts/format.py --path infini_train example --check + - name: Run custom style check + run: | + python3 scripts/style_check.py --path infini_train example + + - name: Configure compile database for IWYU + # Keep IWYU advisory until the existing codebase is fully cleaned up. + continue-on-error: true + run: | + cmake -S . -B build-iwyu -DUSE_CUDA=OFF -DUSE_MACA=OFF -DUSE_MPI=OFF -DUSE_OMP=OFF -DCMAKE_EXPORT_COMPILE_COMMANDS=ON + + - name: Run IWYU check + continue-on-error: true + run: | + if command -v iwyu_tool.py >/dev/null; then + IWYU_TOOL="$(command -v iwyu_tool.py)" + else + IWYU_TOOL="$(command -v iwyu_tool)" + fi + mapfile -t IWYU_SOURCES < <( + find infini_train example -type f \( -name '*.c' -o -name '*.cc' -o -name '*.cpp' -o -name '*.cxx' \) \ + ! -path 'infini_train/src/core/ccl/cuda/*' \ + ! -path 'infini_train/src/core/runtime/cuda/*' \ + ! -path 'infini_train/src/core/ccl/maca/*' \ + ! 
-path 'infini_train/src/core/runtime/maca/*' + ) + "${IWYU_TOOL}" -p build-iwyu -j "$(nproc)" "${IWYU_SOURCES[@]}" diff --git a/CMakeLists.txt b/CMakeLists.txt index 4c6da822..55aa3c1c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -217,6 +217,11 @@ if(BUILD_TEST) add_subdirectory(tests) endif() +if(USE_MACA) + add_executable(test_maca_allocator test/runtime/test_maca_allocator.cc) + link_infini_train_exe(test_maca_allocator) +endif() + # Negative compile test: missing dtype registration must fail at compile time. set(DTYPE_DISPATCH_COMPILE_FAIL_SOURCE ${PROJECT_SOURCE_DIR}/tests/dtype/test_dtype_dispatch_compile_fail.cc) diff --git a/example/common/tokenizer.cc b/example/common/tokenizer.cc index 9541454a..7b5e1a44 100644 --- a/example/common/tokenizer.cc +++ b/example/common/tokenizer.cc @@ -9,11 +9,12 @@ #include "glog/logging.h" -#include "example/common/utils.h" #include "infini_train/include/nn/functional.h" #include "infini_train/include/nn/modules/module.h" #include "infini_train/include/tensor.h" +#include "example/common/utils.h" + namespace infini_train { constexpr uint32_t kGpt2Eot = 50256; diff --git a/example/gpt2/checkpoint_loader.cc b/example/gpt2/checkpoint_loader.cc index 4a7789e9..5a42e32e 100644 --- a/example/gpt2/checkpoint_loader.cc +++ b/example/gpt2/checkpoint_loader.cc @@ -12,8 +12,6 @@ #include "glog/logging.h" -#include "example/common/utils.h" -#include "example/gpt2/config.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/sparse.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" @@ -24,6 +22,9 @@ #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" +#include "example/common/utils.h" +#include "example/gpt2/config.h" + using namespace infini_train; namespace nn = infini_train::nn; @@ -101,7 +102,7 @@ std::shared_ptr LoadFromLLMC(const std::string &filepath) // ========== pp_size:num_stages; vpp_size: 
num_chunks_per_stage ========== int pp_size = nn::parallel::global::GetPipelineParallelSize(); int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); - auto pp_rank = nn::parallel::pp_rank; + auto pp_rank = nn::parallel::tls_pp_rank; auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, pp_rank, vpp_size); // ========== layer to chunk ========== @@ -110,7 +111,7 @@ std::shared_ptr LoadFromLLMC(const std::string &filepath) for (int i = start; i < end; ++i) { owned_layers[i] = true; } } - auto tp_rank = nn::parallel::tp_rank; + auto tp_rank = nn::parallel::tls_tp_rank; // calculate xx_size_per_partition const int64_t vpp = model_vocab_size / tp_size; const int64_t v_start = static_cast(tp_rank) * vpp; diff --git a/example/gpt2/main.cc b/example/gpt2/main.cc index c12b5a28..56bdc5f9 100644 --- a/example/gpt2/main.cc +++ b/example/gpt2/main.cc @@ -137,7 +137,7 @@ void Train(const nn::parallel::Rank &rank) { // Set thread-local global rank // TODO(dcj): Use DeviceGuardImpl to get GlobalRank later. - nn::parallel::global::thread_global_rank = rank.GlobalRank(); + nn::parallel::global::tls_thread_global_rank = rank.GlobalRank(); const ProcessGroup *ddp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; @@ -158,15 +158,14 @@ void Train(const nn::parallel::Rank &rank) { GetTensorParallelGroupRanks(rank.GlobalRank())); tp_rank = tp_pg->GetGroupRank(rank.GlobalRank()); // NOTE(zbl): Reserved for VocabParallelEmbedding - nn::parallel::tp_rank = tp_rank; + nn::parallel::tls_tp_rank = tp_rank; } if (pp_world_size > 1) { pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()), GetPipelineParallelGroupRanks(rank.GlobalRank())); pp_rank = pp_pg->GetGroupRank(rank.GlobalRank()); - - nn::parallel::pp_rank = pp_rank; + nn::parallel::tls_pp_rank = pp_rank; } } else { device = FLAGS_device == kDeviceCPU ? 
Device() : Device(Device::DeviceType::kCUDA, 0); diff --git a/example/llama3/checkpoint_loader.cc b/example/llama3/checkpoint_loader.cc index f29bc540..1d8d559b 100644 --- a/example/llama3/checkpoint_loader.cc +++ b/example/llama3/checkpoint_loader.cc @@ -12,8 +12,6 @@ #include "glog/logging.h" -#include "example/common/utils.h" -#include "example/llama3/config.h" #include "infini_train/include/nn/modules/normalization.h" #include "infini_train/include/nn/modules/transformer/causal_self_attention.h" #include "infini_train/include/nn/modules/transformer/mlp.h" @@ -22,6 +20,9 @@ #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" +#include "example/common/utils.h" +#include "example/llama3/config.h" + using namespace infini_train; namespace nn = infini_train::nn; @@ -86,7 +87,7 @@ std::shared_ptr LoadFromLLMC(const std::string &filepath) // ========== pp_size:num_stages; vpp_size: num_chunks_per_stage ========== int pp_size = nn::parallel::global::GetPipelineParallelSize(); int vpp_size = nn::parallel::global::GetVirtualPipelineParallelSize(); - auto pp_rank = nn::parallel::pp_rank; + auto pp_rank = nn::parallel::tls_pp_rank; auto [is_first_stage, is_last_stage, layer_ranges_per_chunk] = nn::parallel::PipelineParallel::GetStageInfo(n_layer, pp_size, pp_rank, vpp_size); // ========== layer to chunk ========== @@ -96,7 +97,7 @@ std::shared_ptr LoadFromLLMC(const std::string &filepath) } const int tp_size = nn::parallel::global::GetTensorParallelSize(); - const int tp_rank = nn::parallel::tp_rank; + const int tp_rank = nn::parallel::tls_tp_rank; CHECK_EQ(n_embd % tp_size, 0) << "n_embd must be divisible by TP world size."; CHECK_EQ(n_head % tp_size, 0) << "n_head must be divisible by TP world size."; diff --git a/example/llama3/main.cc b/example/llama3/main.cc index 117551d5..59843a52 100644 --- a/example/llama3/main.cc +++ b/example/llama3/main.cc @@ -122,7 +122,7 @@ void Train(const nn::parallel::Rank &rank) { int 
pp_rank = 0; // Set thread-local global rank - nn::parallel::global::thread_global_rank = rank.GlobalRank(); + nn::parallel::global::tls_thread_global_rank = rank.GlobalRank(); const ProcessGroup *ddp_pg = nullptr; const ProcessGroup *tp_pg = nullptr; @@ -143,15 +143,14 @@ void Train(const nn::parallel::Rank &rank) { GetTensorParallelGroupRanks(rank.GlobalRank())); tp_rank = tp_pg->GetGroupRank(rank.GlobalRank()); // NOTE(zbl): Reserved for VocabParallelEmbedding - nn::parallel::tp_rank = tp_rank; + nn::parallel::tls_tp_rank = tp_rank; } if (pp_world_size > 1) { pp_pg = pg_factory->GetOrCreate(GetPipelineParallelProcessGroupName(rank.GlobalRank()), GetPipelineParallelGroupRanks(rank.GlobalRank())); pp_rank = pp_pg->GetGroupRank(rank.GlobalRank()); - - nn::parallel::pp_rank = pp_rank; + nn::parallel::tls_pp_rank = pp_rank; } } else { device = FLAGS_device == kDeviceCPU ? Device() : Device(Device::DeviceType::kCUDA, 0); diff --git a/infini_train/include/autograd/grad_mode.h b/infini_train/include/autograd/grad_mode.h index 65157387..1162152c 100644 --- a/infini_train/include/autograd/grad_mode.h +++ b/infini_train/include/autograd/grad_mode.h @@ -5,13 +5,12 @@ namespace infini_train::autograd { class GradMode { public: - // Whether to enable Autograd (enabled by default) - static bool IsEnabled() { return grad_enabled_; } - static void SetEnabled(bool enabled) { grad_enabled_ = enabled; } + // Whether to enable Autograd (enabled by default). 
+ static bool IsEnabled() { return tls_grad_enabled_; } + static void SetEnabled(bool enabled) { tls_grad_enabled_ = enabled; } private: - // grad mode should be thread_local - static thread_local bool grad_enabled_; + static thread_local bool tls_grad_enabled_; }; // RAII: Disable grad (align with torch.no_grad) diff --git a/infini_train/include/nn/parallel/global.h b/infini_train/include/nn/parallel/global.h index 9373100f..9fc9a7d0 100644 --- a/infini_train/include/nn/parallel/global.h +++ b/infini_train/include/nn/parallel/global.h @@ -6,7 +6,7 @@ namespace infini_train::nn::parallel::global { -extern thread_local int thread_global_rank; +extern thread_local int tls_thread_global_rank; enum Axis : uint8_t { DP = 0, TP = 1, PP = 2, AXIS_COUNT = 3 }; diff --git a/infini_train/include/nn/parallel/pp/pipeline_parallel.h b/infini_train/include/nn/parallel/pp/pipeline_parallel.h index 25939bdc..b7e744ae 100644 --- a/infini_train/include/nn/parallel/pp/pipeline_parallel.h +++ b/infini_train/include/nn/parallel/pp/pipeline_parallel.h @@ -16,7 +16,7 @@ namespace infini_train::nn::parallel { class PipelineStage; class PipelineSchedule; -extern thread_local int pp_rank; +extern thread_local int tls_pp_rank; struct StageInfo { bool is_first_stage; diff --git a/infini_train/include/nn/parallel/tensor_parallel.h b/infini_train/include/nn/parallel/tensor_parallel.h index 0611dfbb..ecaac3e7 100644 --- a/infini_train/include/nn/parallel/tensor_parallel.h +++ b/infini_train/include/nn/parallel/tensor_parallel.h @@ -16,7 +16,7 @@ namespace infini_train::nn::parallel { // NOTE(zbl): Reserved for VocabParallelEmbedding, since rank is needed in its constructor before any Device exists // On other occasions, should use Device::Rank() -extern thread_local int tp_rank; +extern thread_local int tls_tp_rank; class ColumnParallelLinear : public nn::CloneableModule { public: diff --git a/infini_train/include/profiler.h b/infini_train/include/profiler.h index adea9e3f..fb800cc8 100644 --- 
a/infini_train/include/profiler.h +++ b/infini_train/include/profiler.h @@ -17,23 +17,23 @@ namespace core { class Event; } -inline thread_local int g_profiling_depth = 0; +inline thread_local int tls_profiling_depth = 0; struct ProfileContext { std::string name; Device::DeviceType device; }; -inline thread_local ProfileContext g_profile_context; +inline thread_local ProfileContext tls_profile_context; inline void SetProfileContext(const std::string &name, Device::DeviceType device) { - if (g_profiling_depth == 0) { - g_profile_context.name = name; - g_profile_context.device = device; + if (tls_profiling_depth == 0) { + tls_profile_context.name = name; + tls_profile_context.device = device; } } -inline const ProfileContext &GetProfileContext() { return g_profile_context; } +inline const ProfileContext &GetProfileContext() { return tls_profile_context; } struct KernelProfileInfo { int64_t host_total_us = 0; @@ -89,13 +89,14 @@ class Profiler { std::string current_tag_ = "Untagged"; // thread-local tracking - thread_local static inline std::map cpu_timing_map_; + thread_local static inline std::map + tls_cpu_timing_map_; struct EventPair { core::Event *start = nullptr; core::Event *stop = nullptr; }; - thread_local static inline std::map device_timing_map_; + thread_local static inline std::map tls_device_timing_map_; }; } // namespace infini_train diff --git a/infini_train/include/utils/global_module_hook_registry.h b/infini_train/include/utils/global_module_hook_registry.h index 1e3c509e..874e5ce6 100644 --- a/infini_train/include/utils/global_module_hook_registry.h +++ b/infini_train/include/utils/global_module_hook_registry.h @@ -1,12 +1,13 @@ #pragma once -#include "infini_train/include/common/hook.h" -#include "infini_train/include/tensor.h" #include #include #include #include +#include "infini_train/include/common/hook.h" +#include "infini_train/include/tensor.h" + namespace infini_train { namespace nn { class Module; diff --git 
a/infini_train/src/autograd/grad_mode.cc b/infini_train/src/autograd/grad_mode.cc index 28a6e693..f09b494d 100644 --- a/infini_train/src/autograd/grad_mode.cc +++ b/infini_train/src/autograd/grad_mode.cc @@ -1,5 +1,5 @@ #include "infini_train/include/autograd/grad_mode.h" namespace infini_train::autograd { -thread_local bool GradMode::grad_enabled_ = true; +thread_local bool GradMode::tls_grad_enabled_ = true; } // namespace infini_train::autograd diff --git a/infini_train/src/core/ccl/cuda/nccl_impl.cc b/infini_train/src/core/ccl/cuda/nccl_impl.cc index 9e4b1a0d..bfb44a6e 100644 --- a/infini_train/src/core/ccl/cuda/nccl_impl.cc +++ b/infini_train/src/core/ccl/cuda/nccl_impl.cc @@ -1,8 +1,9 @@ #include "infini_train/src/core/ccl/cuda/nccl_impl.h" -#include #include +#include + #include "glog/logging.h" #include "infini_train/include/common/cuda/common_cuda.h" diff --git a/infini_train/src/kernels/cpu/embedding.cc b/infini_train/src/kernels/cpu/embedding.cc index 6b3a5aa6..a1da926f 100644 --- a/infini_train/src/kernels/cpu/embedding.cc +++ b/infini_train/src/kernels/cpu/embedding.cc @@ -46,7 +46,7 @@ std::shared_ptr EmbeddingBackward(const std::shared_ptr &input, for (int i = 0; i < input->NumElements(); ++i) { int idx = static_cast(static_cast(input->DataPtr())[i]); for (int j = 0; j < embedding_dim; ++j) { - static_cast(grad_weight->DataPtr())[idx * embedding_dim + j] // <-- 修复这里 + static_cast(grad_weight->DataPtr())[idx * embedding_dim + j] += static_cast(grad_output->DataPtr())[i * embedding_dim + j]; } } diff --git a/infini_train/src/kernels/cpu/linear.cc b/infini_train/src/kernels/cpu/linear.cc index 9d28a92a..540ec8fc 100644 --- a/infini_train/src/kernels/cpu/linear.cc +++ b/infini_train/src/kernels/cpu/linear.cc @@ -1,3 +1,5 @@ +#include "infini_train/include/autograd/linear.h" + #include #include #include diff --git a/infini_train/src/kernels/cpu/outer.cc b/infini_train/src/kernels/cpu/outer.cc index b41c9551..204fa88c 100644 --- 
a/infini_train/src/kernels/cpu/outer.cc +++ b/infini_train/src/kernels/cpu/outer.cc @@ -1,8 +1,9 @@ #include -#include #include #include +#include + #include "glog/logging.h" #include "infini_train/include/dispatcher.h" diff --git a/infini_train/src/nn/modules/transformer/transformer.cc b/infini_train/src/nn/modules/transformer/transformer.cc index c7e0f28c..83a6c25b 100644 --- a/infini_train/src/nn/modules/transformer/transformer.cc +++ b/infini_train/src/nn/modules/transformer/transformer.cc @@ -198,7 +198,7 @@ std::vector> TransformerLastStage::Forward(const std::ve TransformerModel::TransformerModel(const TransformerConfig config) : CloneableModule(kType), config_(config), stage_info_(nn::parallel::PipelineParallel::GetStageInfo( - config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::pp_rank, + config_.n_layer, nn::parallel::global::GetPipelineParallelSize(), nn::parallel::tls_pp_rank, nn::parallel::global::GetVirtualPipelineParallelSize())) { auto tp_world_size = nn::parallel::global::GetTensorParallelSize(); diff --git a/infini_train/src/nn/modules/transformer/transformer_config.cc b/infini_train/src/nn/modules/transformer/transformer_config.cc index b8947d4b..f7a7a34a 100644 --- a/infini_train/src/nn/modules/transformer/transformer_config.cc +++ b/infini_train/src/nn/modules/transformer/transformer_config.cc @@ -8,7 +8,7 @@ bool TransformerConfig::UseGQA() const { return n_kv_head < n_head; } int TransformerConfig::GetChunkSize() const { auto stage_info = parallel::PipelineParallel::GetStageInfo(n_layer, parallel::global::GetPipelineParallelSize(), - parallel::pp_rank, + parallel::tls_pp_rank, parallel::global::GetVirtualPipelineParallelSize()); return stage_info.layer_ranges_per_chunk.size(); } diff --git a/infini_train/src/nn/parallel/global.cc b/infini_train/src/nn/parallel/global.cc index 0e704647..3680f204 100644 --- a/infini_train/src/nn/parallel/global.cc +++ b/infini_train/src/nn/parallel/global.cc @@ -22,7 +22,7 @@ 
std::string GetEnvAsStr(const std::string &name, const std::string &default_valu namespace infini_train::nn::parallel::global { -thread_local int thread_global_rank = 0; +thread_local int tls_thread_global_rank = 0; void Layout::InitStrides() { // Calculate strides diff --git a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc index c0369cde..7361ed9e 100644 --- a/infini_train/src/nn/parallel/pp/pipeline_parallel.cc +++ b/infini_train/src/nn/parallel/pp/pipeline_parallel.cc @@ -15,7 +15,7 @@ namespace { constexpr char kModuleName[] = "module"; } // namespace -thread_local int pp_rank = 0; +thread_local int tls_pp_rank = 0; void PipelineParallel::BuildPipelineStage(const std::vector> &recv_shape, Device device, std::vector> &&chunks) { diff --git a/infini_train/src/nn/parallel/rank.cc b/infini_train/src/nn/parallel/rank.cc index 617a8e50..3423e190 100644 --- a/infini_train/src/nn/parallel/rank.cc +++ b/infini_train/src/nn/parallel/rank.cc @@ -1,4 +1,5 @@ #include "infini_train/include/nn/parallel/rank.h" + #include "infini_train/include/nn/parallel/global.h" namespace infini_train::nn::parallel { diff --git a/infini_train/src/nn/parallel/tensor_parallel.cc b/infini_train/src/nn/parallel/tensor_parallel.cc index 44ab8189..e1184594 100644 --- a/infini_train/src/nn/parallel/tensor_parallel.cc +++ b/infini_train/src/nn/parallel/tensor_parallel.cc @@ -21,7 +21,7 @@ namespace infini_train::nn::parallel { // NOTE(zbl): Reserved for VocabParallelEmbedding, since rank is needed in its constructor before any Device exists // On other occasions, should use Device::Rank() -thread_local int tp_rank = 0; +thread_local int tls_tp_rank = 0; namespace { // Comm Kernel Call Functions @@ -389,7 +389,7 @@ VocabParallelEmbedding::VocabParallelEmbedding(int64_t num_embeddings, int64_t e << "num_embeddings must be divisible by TP world size for VocabParallelEmbedding"; vocab_size_per_partition_ = num_embeddings / tp_size; - 
vocab_start_index_ = static_cast(tp_rank) * vocab_size_per_partition_; + vocab_start_index_ = static_cast(tls_tp_rank) * vocab_size_per_partition_; vocab_end_index_ = vocab_start_index_ + vocab_size_per_partition_; parameters_[kParamWeightName] diff --git a/infini_train/src/profiler.cc b/infini_train/src/profiler.cc index d53cc351..483b5d21 100644 --- a/infini_train/src/profiler.cc +++ b/infini_train/src/profiler.cc @@ -38,10 +38,10 @@ int GetRank(Device::DeviceType device) { } void Profiler::StartRecord(const std::string &name, Device::DeviceType device) { - if (g_profiling_depth++ > 0) { + if (tls_profiling_depth++ > 0) { return; } - cpu_timing_map_[name] = std::chrono::high_resolution_clock::now(); + tls_cpu_timing_map_[name] = std::chrono::high_resolution_clock::now(); if (device == Device::DeviceType::kCPU) { return; @@ -52,11 +52,11 @@ void Profiler::StartRecord(const std::string &name, Device::DeviceType device) { auto current_device = Device(device, static_cast(device_id)); auto *stream = impl->GetStream(current_device); - auto it = device_timing_map_.find(name); - if (it != device_timing_map_.end()) { + auto it = tls_device_timing_map_.find(name); + if (it != tls_device_timing_map_.end()) { impl->EventDestroy(it->second.start); impl->EventDestroy(it->second.stop); - device_timing_map_.erase(it); + tls_device_timing_map_.erase(it); } core::Event *start = nullptr; @@ -67,13 +67,13 @@ void Profiler::StartRecord(const std::string &name, Device::DeviceType device) { // Make sure the compute stream has done waiting, and ready for the execution of next op impl->SynchronizeStream(stream); // Start record after waiting - cpu_timing_map_[name] = std::chrono::high_resolution_clock::now(); + tls_cpu_timing_map_[name] = std::chrono::high_resolution_clock::now(); impl->EventRecord(start, stream); - device_timing_map_[name] = {start, stop}; + tls_device_timing_map_[name] = {start, stop}; } void Profiler::EndRecord(const std::string &name, Device::DeviceType device) { - 
if (--g_profiling_depth > 0) { + if (--tls_profiling_depth > 0) { return; } int64_t host_us = 0, device_us = 0; @@ -86,8 +86,8 @@ void Profiler::EndRecord(const std::string &name, Device::DeviceType device) { auto current_device = Device(device, static_cast(rank)); auto *stream = impl->GetStream(current_device); - auto it = device_timing_map_.find(name); - if (it == device_timing_map_.end()) { + auto it = tls_device_timing_map_.find(name); + if (it == tls_device_timing_map_.end()) { LOG(FATAL) << "Start time of " + name + " is not recorded."; } @@ -97,7 +97,7 @@ void Profiler::EndRecord(const std::string &name, Device::DeviceType device) { device_us = static_cast(impl->EventElapsedTime(event_pair.start, event_pair.stop) * 1000.0f); impl->EventDestroy(event_pair.start); impl->EventDestroy(event_pair.stop); - device_timing_map_.erase(it); + tls_device_timing_map_.erase(it); auto [peak_used_mb, peak_reserved_mb] = impl->GetMemPoolPeakMB(current_device); (void)peak_used_mb; @@ -105,10 +105,10 @@ void Profiler::EndRecord(const std::string &name, Device::DeviceType device) { device_str = current_device.ToString(); } - auto cpu_start = cpu_timing_map_[name]; + auto cpu_start = tls_cpu_timing_map_[name]; auto cpu_end = std::chrono::high_resolution_clock::now(); host_us = std::chrono::duration_cast(cpu_end - cpu_start).count(); - cpu_timing_map_.erase(name); + tls_cpu_timing_map_.erase(name); RecordKernel(name, rank, device_str, host_us, device_us, peak_mem_mb); } diff --git a/infini_train/src/utils/precision_checker.cc b/infini_train/src/utils/precision_checker.cc index d2cbd16a..97608a0b 100644 --- a/infini_train/src/utils/precision_checker.cc +++ b/infini_train/src/utils/precision_checker.cc @@ -162,7 +162,7 @@ std::ostream &GetLogStream() { std::lock_guard lock(tls_init_mutex); if (!tls_initialized) { const auto &output_path = PrecisionCheckEnv::Instance().GetOutputPath(); - int global_rank = nn::parallel::global::thread_global_rank; + int global_rank = 
"""Repository style checks that clang-format and black cannot express.

Two rules are enforced:

* every C/C++ ``thread_local`` variable name starts with ``tls_``;
* comments in C/C++ and Python sources contain no CJK characters, i.e. they
  are written in English.

Usage: ``python3 scripts/style_check.py --path DIR_OR_FILE [...]``.
The process exits with status 1 when any violation (or invalid path) is found.
"""

import argparse
import io
import os
import re
import tokenize
from pathlib import Path
from typing import Iterable, Iterator, List, Optional, Sequence, Tuple, Union


CPP_SUFFIXES = {
    ".h",
    ".hh",
    ".hpp",
    ".hxx",
    ".c",
    ".cc",
    ".cpp",
    ".cxx",
    ".cu",
    ".cuh",
    ".mlu",
    ".cl",
}
PY_SUFFIXES = {".py"}
EXCLUDED_DIRS = {".git", "build", "cmake-build-debug", "cmake-build-release", "third_party"}
CJK_RE = re.compile(r"[\u3400-\u9fff\uf900-\ufaff]")
# A thread_local statement runs from the keyword up to the next semicolon.
THREAD_LOCAL_RE = re.compile(r"\bthread_local\b.*?;", re.DOTALL)
# An optionally qualified identifier (with optional array extent) that is
# directly followed by a token that can end a declarator.
TLS_NAME_RE = re.compile(r"(?:(?:[A-Za-z_]\w*)::)*(?P<name>[A-Za-z_]\w*)\s*(?:\[[^\]]*\])?\s*(?==|;|,|\{|\()")

_DECIMAL_DIGITS = frozenset("0123456789")
_HEX_DIGITS = frozenset("0123456789abcdefABCDEF")
_PathLike = Union[str, Path]


def iter_files(paths: Iterable[Path]) -> Iterator[Path]:
    """Yield every file named by, or found underneath, ``paths``.

    Files are yielded as given. Directories are walked recursively in sorted
    order, skipping any directory whose name is in ``EXCLUDED_DIRS``. A path
    that is neither a file nor a directory is reported on stdout and skipped.
    """
    for path in paths:
        if path.is_file():
            yield path
        elif path.is_dir():
            for dirpath, dirnames, filenames in os.walk(path):
                # Prune in place so os.walk does not descend into excluded dirs.
                dirnames[:] = sorted(d for d in dirnames if d not in EXCLUDED_DIRS)
                for filename in sorted(filenames):
                    yield Path(dirpath) / filename
        else:
            print(f"error: {path} is not a file or directory")


def has_cjk(text: str) -> bool:
    """Return True when ``text`` contains a character matched by ``CJK_RE``."""
    return CJK_RE.search(text) is not None


def blank_like(text: str) -> str:
    """Return ``text`` with every character but newlines replaced by a space."""
    return "".join("\n" if ch == "\n" else " " for ch in text)


def _is_digit_separator(text: str, index: int) -> bool:
    """Return True when the apostrophe at ``index`` is a C++14 digit separator.

    A separator sits inside a numeric literal: the token it belongs to starts
    with a decimal digit and the separator is followed by another (hex) digit.
    Prefixed character literals such as ``u8'a'`` start with a letter and are
    therefore not mistaken for numbers.
    """
    if index + 1 >= len(text) or text[index + 1] not in _HEX_DIGITS:
        return False
    start = index
    while start > 0 and (text[start - 1].isalnum() or text[start - 1] in "_'"):
        start -= 1
    return start < index and text[start] in _DECIMAL_DIGITS


def scan_cpp(text: str) -> Tuple[str, List[Tuple[int, str]]]:
    """Split C/C++ source into code-only text and a list of comments.

    Returns ``(stripped, comments)``. ``stripped`` has the same length and the
    same newlines as ``text`` but with comments, string/character literals and
    raw string literals blanked out, so line numbers computed on it match the
    original file. ``comments`` holds ``(start_line, body)`` pairs where
    ``body`` excludes the ``//`` or ``/* */`` delimiters.
    """
    stripped: List[str] = []
    comments: List[Tuple[int, str]] = []
    line = 1
    i = 0
    n = len(text)

    while i < n:
        ch = text[i]
        nxt = text[i + 1] if i + 1 < n else ""

        if ch == "\n":
            stripped.append(ch)
            line += 1
            i += 1
            continue

        # Raw string literal: R"delim( ... )delim". An unterminated literal
        # swallows the rest of the file.
        if ch == "R" and nxt == '"':
            match = re.match(r'R"([^\s\\()]*)\(', text[i:])
            if match:
                end_token = ")" + match.group(1) + '"'
                body_start = i + len(match.group(0))
                end = text.find(end_token, body_start)
                raw_end = n if end == -1 else end + len(end_token)
                segment = text[i:raw_end]
                stripped.append(blank_like(segment))
                line += segment.count("\n")
                i = raw_end
                continue

        # Line comment: runs to (but does not include) the newline.
        if ch == "/" and nxt == "/":
            start_line = line
            i += 2
            stripped.append("  ")
            start = i
            while i < n and text[i] != "\n":
                stripped.append(" ")
                i += 1
            comments.append((start_line, text[start:i]))
            continue

        # Block comment: may span lines; newlines are kept in the stripped text.
        if ch == "/" and nxt == "*":
            start_line = line
            i += 2
            stripped.append("  ")
            comment = []
            while i < n:
                if i + 1 < n and text[i] == "*" and text[i + 1] == "/":
                    stripped.append("  ")
                    i += 2
                    break
                comment.append(text[i])
                if text[i] == "\n":
                    stripped.append("\n")
                    line += 1
                else:
                    stripped.append(" ")
                i += 1
            comments.append((start_line, "".join(comment)))
            continue

        # 1'000'000 is a number, not the start of a character literal; without
        # this guard the rest of the line would be blanked as literal content.
        if ch == "'" and _is_digit_separator(text, i):
            stripped.append(ch)
            i += 1
            continue

        # String or character literal. It ends at the matching quote or, for
        # an unterminated literal, at the end of the line.
        if ch in {'"', "'"}:
            quote = ch
            stripped.append(" ")
            i += 1
            while i < n:
                current = text[i]
                if current == "\n":
                    stripped.append("\n")
                    line += 1
                    i += 1
                    break
                if current == "\\":
                    # Skip the escaped character so \" does not close the literal.
                    stripped.append(" ")
                    i += 1
                    if i < n:
                        if text[i] == "\n":
                            stripped.append("\n")
                            line += 1
                        else:
                            stripped.append(" ")
                        i += 1
                    continue
                stripped.append(" ")
                i += 1
                if current == quote:
                    break
            continue

        stripped.append(ch)
        i += 1

    return "".join(stripped), comments


def remove_template_args(statement: str) -> str:
    """Blank out everything between ``<`` and its matching ``>``.

    The angle brackets themselves become spaces as well; newlines inside the
    brackets are kept, so the result has the same length and line structure.
    A ``>`` seen outside any bracket is left untouched.
    """
    result = []
    depth = 0
    for ch in statement:
        if ch == "<":
            depth += 1
            result.append(" ")
        elif ch == ">" and depth > 0:
            depth -= 1
            result.append(" ")
        elif depth > 0:
            result.append("\n" if ch == "\n" else " ")
        else:
            result.append(ch)
    return "".join(result)


def declaration_prefix(statement: str) -> str:
    """Return the declarator part of ``statement`` followed by ``;``.

    The statement is cut at the first initializer token (``=``, ``{`` or
    ``(``) so identifiers inside the initializer, such as constructor
    arguments, are not mistaken for declared names.
    """
    limit = len(statement)
    for token in ("=", "{", "("):
        pos = statement.find(token)
        if pos != -1:
            limit = min(limit, pos)
    return statement[:limit] + ";"


def check_thread_local_names(path: _PathLike, stripped_text: str) -> List[str]:
    """Report ``thread_local`` variables whose name lacks the ``tls_`` prefix.

    ``stripped_text`` is the code-only text produced by ``scan_cpp``. Each
    error is formatted as ``path:line: message`` where ``line`` is the line of
    the ``thread_local`` keyword.
    """
    errors = []
    for match in THREAD_LOCAL_RE.finditer(stripped_text):
        statement = declaration_prefix(remove_template_args(match.group(0)))
        line = stripped_text.count("\n", 0, match.start()) + 1
        for name_match in TLS_NAME_RE.finditer(statement):
            name = name_match.group("name")
            if not name.startswith("tls_"):
                errors.append(f"{path}:{line}: thread_local variable '{name}' must use the tls_ prefix")
    return errors


def check_cpp_comments(path: _PathLike, comments: Iterable[Tuple[int, str]]) -> List[str]:
    """Report every ``(line, comment)`` pair whose text contains CJK characters."""
    return [f"{path}:{line}: comments must be written in English" for line, comment in comments if has_cjk(comment)]


def check_python_comments(path: _PathLike, text: str) -> List[str]:
    """Report Python ``#`` comments that contain CJK characters.

    String literals are not inspected. If the source cannot be tokenized, the
    violations found so far are kept and one extra error describes the failure.
    """
    errors = []
    try:
        tokens = tokenize.generate_tokens(io.StringIO(text).readline)
        for token in tokens:
            if token.type == tokenize.COMMENT and has_cjk(token.string):
                errors.append(f"{path}:{token.start[0]}: comments must be written in English")
    except (tokenize.TokenError, SyntaxError) as exc:
        # tokenize can also surface IndentationError (a SyntaxError subclass)
        # for malformed input; report it instead of crashing the whole run.
        errors.append(f"{path}: could not tokenize Python file: {exc}")
    return errors


def check_file(path: Path) -> List[str]:
    """Run the checks that apply to ``path`` and return their error messages.

    Files whose suffix is in neither ``CPP_SUFFIXES`` nor ``PY_SUFFIXES`` are
    ignored without being read. Undecodable bytes are replaced rather than
    raising.
    """
    suffix = path.suffix
    if suffix not in CPP_SUFFIXES and suffix not in PY_SUFFIXES:
        return []

    text = path.read_text(encoding="utf-8", errors="replace")
    if suffix in CPP_SUFFIXES:
        stripped_text, comments = scan_cpp(text)
        return check_thread_local_names(path, stripped_text) + check_cpp_comments(path, comments)
    return check_python_comments(path, text)


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Command-line entry point.

    ``argv`` defaults to ``sys.argv[1:]``. Prints every violation and raises
    ``SystemExit(1)`` when there is at least one; a ``--path`` entry that is
    neither a file nor a directory counts as a violation so that a mistyped
    path cannot make the check pass vacuously.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", nargs="+", type=Path, required=True, help="Files or directories to check.")
    args = parser.parse_args(argv)

    errors = []
    valid_paths = []
    for path in args.path:
        if path.is_file() or path.is_dir():
            valid_paths.append(path)
        else:
            errors.append(f"error: {path} is not a file or directory")

    for file in iter_files(valid_paths):
        errors.extend(check_file(file))

    if errors:
        for error in errors:
            print(error)
        raise SystemExit(1)

    print("Style check passed.")


if __name__ == "__main__":
    main()