Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion csrc/backends/attention_backends.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ enum class AttentionBackend {
STATIC_ATTN,
PAGED_ATTN,
FLASH_ATTN,
HYBRID_ATTN,
FLASHINFER,
Default = STATIC_ATTN
};
Expand All @@ -25,6 +26,8 @@ inline std::ostream &operator<<(std::ostream &os, AttentionBackend backend) {
return os << "AttentionBackend::PAGED_ATTN";
case AttentionBackend::FLASH_ATTN:
return os << "AttentionBackend::FLASH_ATTN";
case AttentionBackend::HYBRID_ATTN:
return os << "AttentionBackend::HYBRID_ATTN";
case AttentionBackend::FLASHINFER:
return os << "AttentionBackend::FLASHINFER";
default:
Expand All @@ -46,12 +49,15 @@ inline AttentionBackend parse_attention_backend(const std::string &backend) {
if (backend == "flash-attn") {
return AttentionBackend::FLASH_ATTN;
}
if (backend == "hybrid-attn") {
return AttentionBackend::HYBRID_ATTN;
}
if (backend == "flashinfer") {
return AttentionBackend::FLASHINFER;
}

throw std::invalid_argument(
"Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, flashinfer");
"Invalid attention_backend: " + backend + ". Valid options are: static-attn, paged-attn, flash-attn, hybrid-attn, flashinfer");
}

} // namespace infinilm::backends
3 changes: 3 additions & 0 deletions csrc/layers/attention/backends/attention_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ AttentionLayer::AttentionLayer(size_t num_heads,
case ::infinilm::backends::AttentionBackend::FLASH_ATTN:
attn_backend_impl_ = std::make_shared<backends::FlashAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
break;
case ::infinilm::backends::AttentionBackend::HYBRID_ATTN:
attn_backend_impl_ = std::make_shared<backends::HybridAttentionImpl>(num_heads, head_size, scale, num_kv_heads, layer_idx);
break;
default:
throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend");
}
Expand Down
3 changes: 2 additions & 1 deletion csrc/layers/attention/backends/attention_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
#include "../../../backends/attention_backends.hpp"
#include "../../../global_state/global_state.hpp"
#include "flash_attn.hpp"
#include "hybrid_attn.hpp"
#include "infinicore/tensor.hpp"
#include "paged_attn.hpp"
#include "static_attn.hpp"
#include <memory>
#include <variant>

namespace infinilm::layers::attention {
using AttentionImpl = std::variant<std::shared_ptr<backends::StaticAttentionImpl>, std::shared_ptr<backends::PagedAttentionImpl>, std::shared_ptr<backends::FlashAttentionImpl>>;
using AttentionImpl = std::variant<std::shared_ptr<backends::StaticAttentionImpl>, std::shared_ptr<backends::PagedAttentionImpl>, std::shared_ptr<backends::FlashAttentionImpl>, std::shared_ptr<backends::HybridAttentionImpl>>;

/**
* @brief Attention layer.
Expand Down
95 changes: 95 additions & 0 deletions csrc/layers/attention/backends/hybrid_attn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#include "hybrid_attn.hpp"

#include "../../../utils.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/paged_attention.hpp"
#include "infinicore/ops/paged_caching.hpp"

#include <stdexcept>

namespace infinilm::layers::attention::backends {

HybridAttentionImpl::HybridAttentionImpl(size_t num_heads,
size_t head_size,
float scale,
size_t num_kv_heads,
size_t layer_idx)
: num_heads_(num_heads),
scale_(scale),
head_dim_(head_size) {

const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config();
if (!infinilm_config.model_config) {
throw std::runtime_error("infinilm::layers::attention::backends::HybridAttentionImpl: model_config is null");
}
max_position_embeddings_ = infinilm_config.model_config->get<size_t>("max_position_embeddings");
}

infinicore::Tensor HybridAttentionImpl::forward(const AttentionLayer &layer,
const infinicore::Tensor &query,
const infinicore::Tensor &key,
const infinicore::Tensor &value,
infinicore::Tensor &kv_cache,
const infinilm::global_state::AttentionMetadata &attn_metadata) const {
auto total_sequence_lengths = attn_metadata.total_sequence_lengths;
auto input_offsets = attn_metadata.input_offsets;
auto block_tables = attn_metadata.block_tables;
auto slot_mapping = attn_metadata.slot_mapping;
auto cu_seqlens = attn_metadata.cu_seqlens;

ASSERT(total_sequence_lengths.has_value());
ASSERT(input_offsets.has_value());
ASSERT(block_tables.has_value());
ASSERT(slot_mapping.has_value());
ASSERT(cu_seqlens.has_value());

auto [k_total, v_total] = do_kv_cache_update(key, value, kv_cache, slot_mapping.value());

size_t seq_len = query->shape()[0];
bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);

infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device());
if (is_prefill) {
infinicore::op::mha_varlen_(
attn_output,
query,
key,
value,
input_offsets.value(),
cu_seqlens.value(),
block_tables.value(),
max_position_embeddings_,
max_position_embeddings_,
std::nullopt,
scale_);
} else {
infinicore::op::paged_attention_(
attn_output,
query,
k_total,
v_total,
block_tables.value(),
total_sequence_lengths.value(),
std::nullopt,
scale_);
}
return attn_output->view({1, seq_len, num_heads_ * head_dim_});
}

std::tuple<infinicore::Tensor, infinicore::Tensor> HybridAttentionImpl::do_kv_cache_update(const infinicore::Tensor key,
const infinicore::Tensor value,
infinicore::Tensor &kv_cache,
const infinicore::Tensor slot_mapping) const {
auto k_cache_layer = kv_cache->narrow({{0, 0, 1}})->squeeze(0);
auto v_cache_layer = kv_cache->narrow({{0, 1, 1}})->squeeze(0);
infinicore::op::paged_caching_(
k_cache_layer,
v_cache_layer,
key,
value,
slot_mapping);

return {k_cache_layer, v_cache_layer};
}

} // namespace infinilm::layers::attention::backends
40 changes: 40 additions & 0 deletions csrc/layers/attention/backends/hybrid_attn.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
#pragma once

#include "../../../global_state/global_state.hpp"
#include "infinicore/tensor.hpp"
#include <tuple>

namespace infinilm::layers::attention {
class AttentionLayer;
}

namespace infinilm::layers::attention::backends {

class HybridAttentionImpl {
public:
HybridAttentionImpl(size_t num_heads,
size_t head_size,
float scale,
size_t num_kv_heads,
size_t layer_idx);

infinicore::Tensor forward(const AttentionLayer &layer,
const infinicore::Tensor &query,
const infinicore::Tensor &key,
const infinicore::Tensor &value,
infinicore::Tensor &kv_cache,
const infinilm::global_state::AttentionMetadata &attn_metadata) const;

std::tuple<infinicore::Tensor, infinicore::Tensor> do_kv_cache_update(const infinicore::Tensor key,
const infinicore::Tensor value,
infinicore::Tensor &kv_cache,
const infinicore::Tensor slot_mapping) const;

private:
size_t num_heads_;
float scale_;
size_t head_dim_;
size_t max_position_embeddings_;
};

} // namespace infinilm::layers::attention::backends
1 change: 1 addition & 0 deletions csrc/models/infinilm_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ std::vector<infinicore::Tensor> InfinilmModel::default_allocate_kv_cache_tensors
case backends::AttentionBackend::FLASH_ATTN: {
;
}
case backends::AttentionBackend::HYBRID_ATTN:
case backends::AttentionBackend::PAGED_ATTN: {
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
if (nullptr == paged_kv_cache_config) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
namespace infinilm::models::minicpm_sala {

std::vector<infinicore::Tensor> minicpm_sala_allocate_kv_cache_tensors(const cache::CacheConfig *cache_config,
const std::shared_ptr<infinilm::config::ModelConfig> &text_config,
const backends::AttentionBackend &attention_backend) {
const std::shared_ptr<infinilm::config::ModelConfig> &text_config,
const backends::AttentionBackend &attention_backend) {
if (nullptr == cache_config) {
return {};
}
Expand Down Expand Up @@ -58,6 +58,7 @@ std::vector<infinicore::Tensor> minicpm_sala_allocate_kv_cache_tensors(const cac
}
break;
}
case backends::AttentionBackend::HYBRID_ATTN:
case backends::AttentionBackend::PAGED_ATTN: {
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
if (nullptr == paged_kv_cache_config) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ std::vector<infinicore::Tensor> qwen3_next_allocate_kv_cache_tensors(
case backends::AttentionBackend::FLASH_ATTN: {
;
}
case backends::AttentionBackend::HYBRID_ATTN:
case backends::AttentionBackend::PAGED_ATTN: {
auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
if (nullptr == paged_kv_cache_config) {
Expand Down
140 changes: 140 additions & 0 deletions docs/hybrid_attn_iluvatar.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# Hybrid Attention on Iluvatar

本文档记录 Iluvatar 平台上的 `hybrid-attn` 路径:Prefill 使用 FlashAttention-2 varlen,Decode 使用 InfiniCore 原生 PagedAttention。

## 改了什么

### InfiniCore

- Iluvatar 被作为 CUDA-compatible ATen device 处理,用于复用 ATen/CUDA tensor 和 stream guard。
- Iluvatar 的 FlashAttention-2 调用走全局 `flash_attn_2_cuda` ABI。
- 适配 Iluvatar 当前 `mha_varlen_fwd` 尾部参数。
- `mha_varlen` 和 `mha_kvcache` 的 FlashAttention 路径使用当前 InfiniCore stream。
- Iluvatar + `--flash-attn` 构建时,`libinfinicore_cpp_api.so` 链接 `flash_attn_2_cuda*.so` 并写入 rpath。
- 构建时同步 PyTorch 的 `_GLIBCXX_USE_CXX11_ABI`,避免 Python 扩展加载时 ABI 符号不匹配。

### InfiniLM

- 新增 `hybrid-attn` attention backend。
- 新增独立 `HybridAttentionImpl`,不改变纯 `flash-attn` 语义。
- `hybrid-attn` 的执行路径:
- Prefill:使用 FA2 varlen,输入为本轮 dense `query/key/value`。
- Decode:使用原生 `paged_attention_`,输入为 paged KV cache。
- `hybrid-attn` 使用 paged KV cache 分配。
- Python CLI/API 层会把 `hybrid-attn` 归一化为 paged cache 路径,避免 cache 类型误配。

## 怎么使用

以下命令以 `/data-aisoft/qyq_models/Qwen2.5-3B-Instruct` 为例。

### 1. 环境变量

FA2 的 `.so` 路径可以直接通过 Python 获取:

```bash
export FLASH_ATTN_2_CUDA_SO=$(python3 -c 'import flash_attn_2_cuda; print(flash_attn_2_cuda.__file__)')
export LD_LIBRARY_PATH=/root/.infini/lib:/usr/local/corex/lib64:/usr/local/corex/lib64/python3/dist-packages/torch/lib:/usr/local/corex/lib64/python3/dist-packages:$LD_LIBRARY_PATH
export PYTHONPATH=/home/zx/InfiniLM/python:/home/zx/InfiniCore/python:/usr/local/corex/lib64/python3/dist-packages:$PYTHONPATH
```

### 2. 构建 InfiniCore

```bash
cd /home/zx/InfiniCore
xmake f --iluvatar-gpu=y --aten=y --flash-attn=/usr/local/corex/lib64/python3/dist-packages
xmake build infinicore_cpp_api
xmake build _infinicore
xmake install -o /root/.infini infinicore_cpp_api
xmake install -o /root/.infini _infinicore
```

同步本地 Python 包中的 InfiniCore 扩展:

```bash
cp -f /root/.infini/lib/libinfinicore_cpp_api.so /home/zx/InfiniCore/python/infinicore/lib/libinfinicore_cpp_api.so
cp -f /root/.infini/lib/_infinicore.cpython-310-x86_64-linux-gnu.so /home/zx/InfiniCore/python/infinicore/lib/_infinicore.cpython-310-x86_64-linux-gnu.so
```

可选检查 `libinfinicore_cpp_api.so` 是否已经链接 FA2:

```bash
readelf -d /root/.infini/lib/libinfinicore_cpp_api.so | grep flash_attn_2_cuda
```

预期能看到类似:

```text
NEEDED Shared library: [/usr/local/corex/lib64/python3/dist-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so]
```

### 3. 构建 InfiniLM

```bash
cd /home/zx/InfiniLM
xmake build _infinilm
xmake install _infinilm
```

### 4. 运行 hybrid-attn 推理

```bash
cd /home/zx/InfiniLM
python3 examples/test_infer.py \
--model /data-aisoft/qyq_models/Qwen2.5-3B-Instruct \
--device iluvatar \
--enable-paged-attn \
--attn hybrid-attn \
--batch-size 1 \
--max-new-tokens 4 \
--prompt "你好" \
--temperature 0.0 \
--top-k 1
```

说明:当前 CLI/API 会将 `hybrid-attn` 归一化到 paged cache 路径;命令中保留 `--enable-paged-attn` 是为了显式表达运行条件。

## Qwen2.5 运行结果

验证环境:

- Platform:Iluvatar
- Model:`/data-aisoft/qyq_models/Qwen2.5-3B-Instruct`
- Attention backend:`hybrid-attn`
- Batch size:1
- Max new tokens:4
- Prompt:`你好`

构建验证:

```text
xmake build infinicore_cpp_api # passed
xmake build _infinicore # passed
xmake build _infinilm # passed
```

推理复现结果:

```text
load weights over! 2431.8737983703613 ms

=================== start generate ====================
Generating: 100%|██████████| 1/1 [00:02<00:00, 2.53s/it]
Resquest 0:
===Query===
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
你好<|im_end|>
<|im_start|>assistant

===Response===
""""

total_time: 2582.32 ms
```

## 当前边界

- 当前稳定验证路径是 Iluvatar + Qwen2.5 + FA2 dense prefill + native paged decode。
- Iluvatar 当前 FA2 varlen 不使用 paged KV cache layout 作为 prefill 输入,hybrid prefill 使用本轮 dense `key/value`。
- `flash-attn` 仍表示纯 FA 路径;`hybrid-attn` 是单独 backend。
5 changes: 4 additions & 1 deletion python/infinilm/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ def __init__(self):
# Multimodal parameters
self.image = self.args.image

if self.attn == "hybrid-attn":
self.enable_paged_attn = True

if self.enable_paged_attn and self.attn == "default":
self.attn = "paged-attn"

Expand All @@ -119,7 +122,7 @@ def _add_common_args(self):
"--attn",
type=str,
default="default",
choices=["default", "paged-attn", "flash-attn"],
choices=["default", "paged-attn", "flash-attn", "hybrid-attn"],
)
self.parser.add_argument("--enable-graph", action="store_true")
self.parser.add_argument(
Expand Down
Loading