diff --git a/conditioner.hpp b/conditioner.hpp
index 45db314b9..7fb1f0cd5 100644
--- a/conditioner.hpp
+++ b/conditioner.hpp
@@ -10,9 +10,14 @@ struct SDCondition {
     struct ggml_tensor* c_vector = nullptr;  // aka y
     struct ggml_tensor* c_concat = nullptr;
 
+    std::vector<struct ggml_tensor*> extra_c_crossattns;
+
     SDCondition() = default;
-    SDCondition(struct ggml_tensor* c_crossattn, struct ggml_tensor* c_vector, struct ggml_tensor* c_concat)
-        : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat) {}
+    SDCondition(struct ggml_tensor* c_crossattn,
+                struct ggml_tensor* c_vector,
+                struct ggml_tensor* c_concat,
+                const std::vector<struct ggml_tensor*>& extra_c_crossattns = {})
+        : c_crossattn(c_crossattn), c_vector(c_vector), c_concat(c_concat), extra_c_crossattns(extra_c_crossattns) {}
 };
 
 struct ConditionerParams {
@@ -1657,18 +1662,23 @@ struct LLMEmbedder : public Conditioner {
     }
 
     std::tuple<std::vector<int>, std::vector<float>> tokenize(std::string text,
-                                                              std::pair<int, int> attn_range,
+                                                              const std::pair<int, int>& attn_range,
                                                               size_t max_length = 0,
                                                               bool padding      = false) {
         std::vector<std::pair<std::string, float>> parsed_attention;
-        parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
-        if (attn_range.second - attn_range.first > 0) {
-            auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
-            parsed_attention.insert(parsed_attention.end(),
-                                    new_parsed_attention.begin(),
-                                    new_parsed_attention.end());
-        }
-        parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
+        if (attn_range.first >= 0 && attn_range.second > 0) {
+            parsed_attention.emplace_back(text.substr(0, attn_range.first), 1.f);
+            if (attn_range.second - attn_range.first > 0) {
+                auto new_parsed_attention = parse_prompt_attention(text.substr(attn_range.first, attn_range.second - attn_range.first));
+                parsed_attention.insert(parsed_attention.end(),
+                                        new_parsed_attention.begin(),
+                                        new_parsed_attention.end());
+            }
+            parsed_attention.emplace_back(text.substr(attn_range.second), 1.f);
+        } else {
+            parsed_attention.emplace_back(text, 1.f);
+        }
+
         {
             std::stringstream ss;
             ss << "[";
@@ -1699,78 +1709,161 @@ struct LLMEmbedder : public Conditioner {
         return {tokens, weights};
     }
 
+    ggml_tensor* encode_prompt(ggml_context* work_ctx,
+                               int n_threads,
+                               const std::string prompt,
+                               const std::pair<int, int>& prompt_attn_range,
+                               int max_length,
+                               int min_length,
+                               std::vector<std::pair<int, ggml_tensor*>> image_embeds,
+                               const std::set<int>& out_layers,
+                               int prompt_template_encode_start_idx) {
+        auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
+        auto& tokens            = std::get<0>(tokens_and_weights);
+        auto& weights           = std::get<1>(tokens_and_weights);
+
+        struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, hidden_size]
+
+        auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
+
+        llm->compute(n_threads,
+                     input_ids,
+                     image_embeds,
+                     out_layers,
+                     &hidden_states,
+                     work_ctx);
+        {
+            auto tensor         = hidden_states;
+            float original_mean = ggml_ext_tensor_mean(tensor);
+            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
+                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
+                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
+                        float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
+                        value *= weights[i1];
+                        ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
+                    }
+                }
+            }
+            float new_mean = ggml_ext_tensor_mean(tensor);
+            ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
+        }
+
+        GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
+
+        int64_t zero_pad_len = 0;
+        if (min_length > 0) {
+            if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
+                zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
+            }
+        }
+
+        ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
+                                                            GGML_TYPE_F32,
+                                                            hidden_states->ne[0],
+                                                            hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
+                                                            hidden_states->ne[2]);
+
+        ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
+            float value = 0.f;
+            if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
+                value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
+            }
+            ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
+        });
+
+        return new_hidden_states;
+    }
+
     SDCondition get_learned_condition(ggml_context* work_ctx,
                                       int n_threads,
                                       const ConditionerParams& conditioner_params) override {
         std::string prompt;
-        std::vector<std::pair<int, ggml_tensor*>> image_embeds;
         std::pair<int, int> prompt_attn_range;
+        std::vector<std::string> extra_prompts;
+        std::vector<std::pair<int, int>> extra_prompts_attn_range;
+        std::vector<std::pair<int, ggml_tensor*>> image_embeds;
         int prompt_template_encode_start_idx = 34;
         int max_length                       = 0;
+        int min_length                       = 0;
         std::set<int> out_layers;
 
-        if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
-            LOG_INFO("QwenImageEditPlusPipeline");
-            prompt_template_encode_start_idx = 64;
-            int image_embed_idx              = 64 + 6;
-
-            int min_pixels          = 384 * 384;
-            int max_pixels          = 560 * 560;
-            std::string placeholder = "<|image_pad|>";
-            std::string img_prompt;
-
-            for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
-                sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
-                double factor        = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
-                int height           = image.height;
-                int width            = image.width;
-                int h_bar            = static_cast<int>(std::round(height / factor)) * factor;
-                int w_bar            = static_cast<int>(std::round(width / factor)) * factor;
-
-                if (static_cast<int64_t>(h_bar) * w_bar > max_pixels) {
-                    double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
-                    h_bar       = std::max(static_cast<int>(factor),
-                                           static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
-                    w_bar       = std::max(static_cast<int>(factor),
-                                           static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
-                } else if (static_cast<int64_t>(h_bar) * w_bar < min_pixels) {
-                    double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
-                    h_bar       = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
-                    w_bar       = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
-                }
-                LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
+        int64_t t0 = ggml_time_ms();
+
+        if (sd_version_is_qwen_image(version)) {
+            if (llm->enable_vision && !conditioner_params.ref_images.empty()) {
+                LOG_INFO("QwenImageEditPlusPipeline");
+                prompt_template_encode_start_idx = 64;
+                int image_embed_idx              = 64 + 6;
+
+                int min_pixels          = 384 * 384;
+                int max_pixels          = 560 * 560;
+                std::string placeholder = "<|image_pad|>";
+                std::string img_prompt;
+
+                for (int i = 0; i < conditioner_params.ref_images.size(); i++) {
+                    sd_image_f32_t image = sd_image_t_to_sd_image_f32_t(*conditioner_params.ref_images[i]);
+                    double factor        = llm->params.vision.patch_size * llm->params.vision.spatial_merge_size;
+                    int height           = image.height;
+                    int width            = image.width;
+                    int h_bar            = static_cast<int>(std::round(height / factor)) * factor;
+                    int w_bar            = static_cast<int>(std::round(width / factor)) * factor;
+
+                    if (static_cast<int64_t>(h_bar) * w_bar > max_pixels) {
+                        double beta = std::sqrt((height * width) / static_cast<double>(max_pixels));
+                        h_bar       = std::max(static_cast<int>(factor),
+                                               static_cast<int>(std::floor(height / beta / factor)) * static_cast<int>(factor));
+                        w_bar       = std::max(static_cast<int>(factor),
+                                               static_cast<int>(std::floor(width / beta / factor)) * static_cast<int>(factor));
+                    } else if (static_cast<int64_t>(h_bar) * w_bar < min_pixels) {
+                        double beta = std::sqrt(static_cast<double>(min_pixels) / (height * width));
+                        h_bar       = static_cast<int>(std::ceil(height * beta / factor)) * static_cast<int>(factor);
+                        w_bar       = static_cast<int>(std::ceil(width * beta / factor)) * static_cast<int>(factor);
+                    }
+
+                    LOG_DEBUG("resize conditioner ref image %d from %dx%d to %dx%d", i, image.height, image.width, h_bar, w_bar);
 
-                sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
-                free(image.data);
-                image.data = nullptr;
+                    sd_image_f32_t resized_image = clip_preprocess(image, w_bar, h_bar);
+                    free(image.data);
+                    image.data = nullptr;
 
-                ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
-                sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
-                free(resized_image.data);
-                resized_image.data = nullptr;
+                    ggml_tensor* image_tensor = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, resized_image.width, resized_image.height, 3, 1);
+                    sd_image_f32_to_ggml_tensor(resized_image, image_tensor, false);
+                    free(resized_image.data);
+                    resized_image.data = nullptr;
 
-                ggml_tensor* image_embed = nullptr;
-                llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
-                image_embeds.emplace_back(image_embed_idx, image_embed);
-                image_embed_idx += 1 + image_embed->ne[1] + 6;
+                    ggml_tensor* image_embed = nullptr;
+                    llm->encode_image(n_threads, image_tensor, &image_embed, work_ctx);
+                    image_embeds.emplace_back(image_embed_idx, image_embed);
+                    image_embed_idx += 1 + image_embed->ne[1] + 6;
 
-                img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>";  // [24669, 220, index, 25, 220, 151652]
-                int64_t num_image_tokens = image_embed->ne[1];
-                img_prompt.reserve(num_image_tokens * placeholder.size());
-                for (int j = 0; j < num_image_tokens; j++) {
-                    img_prompt += placeholder;
+                    img_prompt += "Picture " + std::to_string(i + 1) + ": <|vision_start|>";  // [24669, 220, index, 25, 220, 151652]
+                    int64_t num_image_tokens = image_embed->ne[1];
+                    img_prompt.reserve(num_image_tokens * placeholder.size());
+                    for (int j = 0; j < num_image_tokens; j++) {
+                        img_prompt += placeholder;
+                    }
+                    img_prompt += "<|vision_end|>";
                 }
-                img_prompt += "<|vision_end|>";
-            }
 
-            prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
-            prompt += img_prompt;
+                prompt = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n";
+                prompt += img_prompt;
 
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
+                prompt_attn_range.first = static_cast<int>(prompt.size());
+                prompt += conditioner_params.text;
+                prompt_attn_range.second = static_cast<int>(prompt.size());
+
+                prompt += "<|im_end|>\n<|im_start|>assistant\n";
+            } else {
+                prompt_template_encode_start_idx = 34;
+
+                prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
+
+                prompt_attn_range.first = static_cast<int>(prompt.size());
+                prompt += conditioner_params.text;
+                prompt_attn_range.second = static_cast<int>(prompt.size());
 
-            prompt += "<|im_end|>\n<|im_start|>assistant\n";
+                prompt += "<|im_end|>\n<|im_start|>assistant\n";
+            }
         } else if (sd_version_is_flux2(version)) {
             prompt_template_encode_start_idx = 0;
             out_layers                       = {10, 20, 30};
@@ -1786,13 +1879,23 @@ struct LLMEmbedder : public Conditioner {
             prompt_template_encode_start_idx = 0;
             out_layers                       = {35};  // -2
 
-            prompt = "<|im_start|>user\n";
+            if (!conditioner_params.ref_images.empty()) {
+                LOG_INFO("ZImageOmniPipeline");
+                prompt = "<|im_start|>user\n<|vision_start|>";
+                for (int i = 0; i < conditioner_params.ref_images.size() - 1; i++) {
+                    extra_prompts.push_back("<|vision_end|><|vision_start|>");
+                }
+                extra_prompts.push_back("<|vision_end|>" + conditioner_params.text + "<|im_end|>\n<|im_start|>assistant\n<|vision_start|>");
+                extra_prompts.push_back("<|vision_end|><|im_end|>");
+                extra_prompts_attn_range.resize(extra_prompts.size());  // default {0, 0} ranges: no emphasis parsing for extra prompts
+            } else {
+                prompt = "<|im_start|>user\n";
 
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
+                prompt_attn_range.first = static_cast<int>(prompt.size());
+                prompt += conditioner_params.text;
+                prompt_attn_range.second = static_cast<int>(prompt.size());
 
-            prompt += "<|im_end|>\n<|im_start|>assistant\n";
+                prompt += "<|im_end|>\n<|im_start|>assistant\n";
+            }
         } else if (sd_version_is_flux2(version)) {
             prompt_template_encode_start_idx = 0;
             out_layers                       = {10, 20, 30};
@@ -1804,6 +1907,8 @@ struct LLMEmbedder : public Conditioner {
             prompt_attn_range.second = prompt.size();
 
             prompt += "[/INST]";
+
+            min_length = 512;
         } else if (version == VERSION_OVIS_IMAGE) {
             prompt_template_encode_start_idx = 28;
             max_length                       = prompt_template_encode_start_idx + 256;
@@ -1816,81 +1921,36 @@ struct LLMEmbedder : public Conditioner {
 
             prompt += "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n";
         } else {
-            prompt_template_encode_start_idx = 34;
-
-            prompt = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n";
-
-            prompt_attn_range.first = static_cast<int>(prompt.size());
-            prompt += conditioner_params.text;
-            prompt_attn_range.second = static_cast<int>(prompt.size());
-
-            prompt += "<|im_end|>\n<|im_start|>assistant\n";
-        }
-
-        auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
-        auto& tokens            = std::get<0>(tokens_and_weights);
-        auto& weights           = std::get<1>(tokens_and_weights);
-
-        int64_t t0 = ggml_time_ms();
-
-        struct ggml_tensor* hidden_states = nullptr;  // [N, n_token, 3584]
-
-        auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
-
-        llm->compute(n_threads,
-                     input_ids,
-                     image_embeds,
-                     out_layers,
-                     &hidden_states,
-                     work_ctx);
-        {
-            auto tensor         = hidden_states;
-            float original_mean = ggml_ext_tensor_mean(tensor);
-            for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
-                for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
-                    for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
-                        float value = ggml_ext_tensor_get_f32(tensor, i0, i1, i2);
-                        value *= weights[i1];
-                        ggml_ext_tensor_set_f32(tensor, value, i0, i1, i2);
-                    }
-                }
-            }
-            float new_mean = ggml_ext_tensor_mean(tensor);
-            ggml_ext_tensor_scale_inplace(tensor, (original_mean / new_mean));
-        }
-
-        GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
-
-        int64_t min_length = 0;
-        if (sd_version_is_flux2(version)) {
-            min_length = 512;
-        }
-
-        int64_t zero_pad_len = 0;
-        if (min_length > 0) {
-            if (hidden_states->ne[1] - prompt_template_encode_start_idx < min_length) {
-                zero_pad_len = min_length - hidden_states->ne[1] + prompt_template_encode_start_idx;
-            }
-        }
+            GGML_ABORT("unknown version %d", version);
+        }
+
+        auto hidden_states = encode_prompt(work_ctx,
+                                           n_threads,
+                                           prompt,
+                                           prompt_attn_range,
+                                           max_length,
+                                           min_length,
+                                           image_embeds,
+                                           out_layers,
+                                           prompt_template_encode_start_idx);
+
+        std::vector<ggml_tensor*> extra_hidden_states_vec;
+        for (int i = 0; i < extra_prompts.size(); i++) {
+            auto extra_hidden_states = encode_prompt(work_ctx,
+                                                     n_threads,
+                                                     extra_prompts[i],
+                                                     extra_prompts_attn_range[i],
+                                                     max_length,
+                                                     min_length,
+                                                     image_embeds,
+                                                     out_layers,
+                                                     prompt_template_encode_start_idx);
+            extra_hidden_states_vec.push_back(extra_hidden_states);
+        }
 
-        ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
-                                                            GGML_TYPE_F32,
-                                                            hidden_states->ne[0],
-                                                            hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
-                                                            hidden_states->ne[2]);
-
-        ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
-            float value = 0.f;
-            if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
-                value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
-            }
-            ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
-        });
-
-        // print_ggml_tensor(new_hidden_states);
-
         int64_t t1 = ggml_time_ms();
         LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
 
-        return {new_hidden_states, nullptr, nullptr};
+        return {hidden_states, nullptr, nullptr, extra_hidden_states_vec};
     }
 };
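Note on the weighting pass inside encode_prompt above: per-token emphasis weights are multiplied into every channel of the corresponding hidden-state row, and the whole tensor is then rescaled so its global mean is unchanged. A minimal scalar sketch of the same idea (names hypothetical, row-major [n_token, hidden] buffer assumed):

    #include <cstddef>
    #include <vector>

    // Scale each token row by its weight, then restore the original global mean
    // so the overall activation magnitude is preserved.
    void weight_hidden_states(std::vector<float>& h, size_t n_token, size_t hidden,
                              const std::vector<float>& weights) {
        double sum_before = 0.0;
        for (float v : h)
            sum_before += v;
        for (size_t t = 0; t < n_token; t++)
            for (size_t c = 0; c < hidden; c++)
                h[t * hidden + c] *= weights[t];
        double sum_after = 0.0;
        for (float v : h)
            sum_after += v;
        if (sum_after != 0.0) {
            float s = static_cast<float>(sum_before / sum_after);  // mean ratio == sum ratio
            for (float& v : h)
                v *= s;
        }
    }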
diff --git a/diffusion_model.hpp b/diffusion_model.hpp
index 06cbecc28..89a26a4bc 100644
--- a/diffusion_model.hpp
+++ b/diffusion_model.hpp
@@ -23,6 +23,8 @@ struct DiffusionParams {
     struct ggml_tensor* vace_context = nullptr;
     float vace_strength              = 1.f;
     std::vector<int> skip_layers     = {};
+    std::vector<ggml_tensor*> extra_contexts;  // for z-image-omni
+    std::vector<ggml_tensor*> ref_clip_feats;  // for z-image-omni
 };
 
 struct DiffusionModel {
@@ -436,12 +438,14 @@ struct ZImageModel : public DiffusionModel {
                  DiffusionParams diffusion_params,
                  struct ggml_tensor** output     = nullptr,
                  struct ggml_context* output_ctx = nullptr) override {
+        std::vector<ggml_tensor*> contexts = {diffusion_params.context};
+        contexts.insert(contexts.end(), diffusion_params.extra_contexts.begin(), diffusion_params.extra_contexts.end());
         return z_image.compute(n_threads,
                                diffusion_params.x,
                                diffusion_params.timesteps,
-                               diffusion_params.context,
+                               contexts,
                                diffusion_params.ref_latents,
-                               true,  // increase_ref_index
+                               diffusion_params.ref_clip_feats,
                                output,
                                output_ctx);
     }
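For orientation: the new SDCondition::extra_c_crossattns from conditioner.hpp travel through DiffusionParams::extra_contexts (wired up in stable-diffusion.cpp further down), and ZImageModel::compute above prepends the primary context before calling z_image.compute. A sketch of that packing step in isolation (Tensor is a hypothetical stand-in for ggml_tensor):

    #include <vector>

    struct Tensor;  // stand-in for ggml_tensor

    // Primary cross-attention context first, then any extra contexts, in order.
    std::vector<Tensor*> pack_contexts(Tensor* context, const std::vector<Tensor*>& extras) {
        std::vector<Tensor*> contexts = {context};
        contexts.insert(contexts.end(), extras.begin(), extras.end());
        return contexts;
    }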
diff --git a/rope.hpp b/rope.hpp
index 4e6136c11..eb5326052 100644
--- a/rope.hpp
+++ b/rope.hpp
@@ -518,60 +518,117 @@ namespace Rope {
         return (m - (a % m)) % m;
     }
 
-    __STATIC_INLINE__ std::vector<std::vector<float>> gen_z_image_ids(int h,
-                                                                      int w,
+    __STATIC_INLINE__ std::vector<std::vector<float>> gen_z_image_ids(ggml_tensor* x,
+                                                                      const std::vector<ggml_tensor*>& contexts,
+                                                                      const std::vector<ggml_tensor*>& ref_latents,
+                                                                      const std::vector<ggml_tensor*>& siglip_feats,
                                                                       int patch_size,
-                                                                      int bs,
-                                                                      int context_len,
                                                                       int seq_multi_of,
-                                                                      const std::vector<ggml_tensor*>& ref_latents,
-                                                                      bool increase_ref_index) {
-        int padded_context_len = context_len + bound_mod(context_len, seq_multi_of);
-        auto txt_ids           = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f));
-        for (int i = 0; i < bs * padded_context_len; i++) {
-            txt_ids[i][0] = (i % padded_context_len) + 1.f;
+                                                                      int bs) {
+        GGML_ASSERT(contexts.size() > ref_latents.size());
+        GGML_ASSERT(contexts.size() >= siglip_feats.size());
+        int context_cu_len = 1;
+        std::vector<int> context_end_pos;
+        std::vector<std::vector<float>> txt_ids;
+        for (auto context : contexts) {
+            int padded_context_len = context->ne[1] + bound_mod(context->ne[1], seq_multi_of);
+            auto curr_txt_ids      = std::vector<std::vector<float>>(bs * padded_context_len, std::vector<float>(3, 0.0f));
+            for (int i = 0; i < bs * padded_context_len; i++) {
+                curr_txt_ids[i][0] = static_cast<float>((i % padded_context_len) + context_cu_len);
+            }
+            context_cu_len += padded_context_len;
+            context_end_pos.push_back(context_cu_len);
+            context_cu_len += 2;  // for image and siglip tokens
+            txt_ids = concat_ids(txt_ids, curr_txt_ids, bs);
         }
 
-        int axes_dim_num = 3;
-        int index        = padded_context_len + 1;
-        auto img_ids     = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, index);
+        std::vector<std::vector<float>> img_ids;
+        std::vector<ggml_tensor*> all_img = ref_latents;
+        all_img.push_back(x);
+        for (int i = 0; i < all_img.size(); i++) {
+            int axes_dim_num  = 3;
+            int index         = context_end_pos[i];
+            auto curr_img_ids = gen_flux_img_ids(all_img[i]->ne[1], all_img[i]->ne[0], patch_size, bs, axes_dim_num, index);
+
+            int img_pad_len = bound_mod(static_cast<int>(curr_img_ids.size() / bs), seq_multi_of);
+            if (img_pad_len > 0) {
+                std::vector<std::vector<float>> img_pad_ids(bs * img_pad_len, std::vector<float>(3, 0.f));
+                curr_img_ids = concat_ids(curr_img_ids, img_pad_ids, bs);
+            }
+            img_ids = concat_ids(img_ids, curr_img_ids, bs);
+        }
+
+        std::vector<std::vector<float>> sig_ids;
+        for (int i = 0; i < siglip_feats.size(); i++) {
+            int axes_dim_num = 3;
+            int index        = context_end_pos[i] + 1;
+            int h_len        = siglip_feats[i]->ne[1];
+            int w_len        = siglip_feats[i]->ne[0];
+
+            std::vector<std::vector<float>> curr_sig_ids(bs * h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
+
+            // scale position IDs to match img resolution
+            std::vector<float> row_ids = linspace(0, all_img[i]->ne[1] - 1, h_len);
+            std::vector<float> col_ids = linspace(0, all_img[i]->ne[0] - 1, w_len);
+
+            for (int ib = 0; ib < bs; ++ib) {
+                for (int ih = 0; ih < h_len; ++ih) {
+                    for (int iw = 0; iw < w_len; ++iw) {
+                        curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][0] = index;
+                        curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][1] = row_ids[ih];
+                        curr_sig_ids[ib * h_len * w_len + ih * w_len + iw][2] = col_ids[iw];
+                    }
+                }
+            }
 
-        int img_pad_len = bound_mod(static_cast<int>(img_ids.size() / bs), seq_multi_of);
-        if (img_pad_len > 0) {
-            std::vector<std::vector<float>> img_pad_ids(bs * img_pad_len, std::vector<float>(3, 0.f));
-            img_ids = concat_ids(img_ids, img_pad_ids, bs);
+            int sig_pad_len = bound_mod(static_cast<int>(curr_sig_ids.size() / bs), seq_multi_of);
+            if (sig_pad_len > 0) {
+                std::vector<std::vector<float>> sig_pad_ids(bs * sig_pad_len, std::vector<float>(3, 0.f));
+                curr_sig_ids = concat_ids(curr_sig_ids, sig_pad_ids, bs);
+            }
+            sig_ids = concat_ids(sig_ids, curr_sig_ids, bs);
         }
 
         auto ids = concat_ids(txt_ids, img_ids, bs);
-        // ignore ref_latents for now
+        if (!sig_ids.empty()) {
+            ids = concat_ids(ids, sig_ids, bs);
+        }
+
         return ids;
     }
 
     // Generate z_image positional embeddings
-    __STATIC_INLINE__ std::vector<float> gen_z_image_pe(int h,
-                                                        int w,
+    __STATIC_INLINE__ std::vector<float> gen_z_image_pe(ggml_tensor* x,
+                                                        const std::vector<ggml_tensor*>& contexts,
+                                                        const std::vector<ggml_tensor*>& ref_latents,
+                                                        const std::vector<ggml_tensor*>& siglip_feats,
                                                         int patch_size,
-                                                        int bs,
-                                                        int context_len,
                                                         int seq_multi_of,
-                                                        const std::vector<ggml_tensor*>& ref_latents,
-                                                        bool increase_ref_index,
                                                         int theta,
+                                                        const std::vector<int>& axes_dim,
                                                         bool circular_h,
                                                         bool circular_w,
-                                                        const std::vector<int>& axes_dim) {
-        std::vector<std::vector<float>> ids = gen_z_image_ids(h, w, patch_size, bs, context_len, seq_multi_of, ref_latents, increase_ref_index);
+                                                        int bs) {
+        std::vector<std::vector<float>> ids = gen_z_image_ids(x, contexts, ref_latents, siglip_feats, patch_size, seq_multi_of, bs);
 
         std::vector<std::vector<int>> wrap_dims;
         if ((circular_h || circular_w) && bs > 0 && axes_dim.size() >= 3) {
+            int context_len = 0;
+            for (auto context : contexts) {
+                int padded_context_len = context->ne[1] + bound_mod(context->ne[1], seq_multi_of);
+                context_len += padded_context_len;
+            }
+            int h     = x->ne[1];
+            int w     = x->ne[0];
             int pad_h = (patch_size - (h % patch_size)) % patch_size;
             int pad_w = (patch_size - (w % patch_size)) % patch_size;
             int h_len = (h + pad_h) / patch_size;
             int w_len = (w + pad_w) / patch_size;
+
             if (h_len > 0 && w_len > 0) {
                 size_t pos_len = ids.size() / bs;
                 wrap_dims.assign(axes_dim.size(), std::vector<int>(pos_len, 0));
-                size_t cursor = context_len + bound_mod(context_len, seq_multi_of);  // skip text (and its padding)
+                size_t cursor     = context_len;  // skip text (and its padding)
                 size_t img_tokens = static_cast<size_t>(h_len) * static_cast<size_t>(w_len);
                 for (size_t token_i = 0; token_i < img_tokens; ++token_i) {
                     if (circular_h) {
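The id generation above threads a cumulative counter along axis 0: each context block occupies the positions [cu, cu + padded_len), its matching image block starts at context_end_pos[i], and its SigLIP block at context_end_pos[i] + 1, which is why two extra slots are reserved after every context. A standalone sketch of just that bookkeeping, assuming the padded context lengths are already known:

    #include <vector>

    // Mirror the cumulative axis-0 offsets of gen_z_image_ids: ids start at 1
    // (0 is left for padding); after each context, one slot is reserved for the
    // matching image block and one for the matching SigLIP block.
    std::vector<int> context_end_positions(const std::vector<int>& padded_context_lens) {
        std::vector<int> end_pos;
        int cu = 1;
        for (int len : padded_context_lens) {
            cu += len;
            end_pos.push_back(cu);  // image ids for this block start here
            cu += 2;                // reserved: image slot + siglip slot
        }
        return end_pos;
    }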
diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp
index 4b1c00438..9cd23e959 100644
--- a/stable-diffusion.cpp
+++ b/stable-diffusion.cpp
@@ -1528,6 +1528,7 @@ class StableDiffusionGGML {
                          int start_merge_step,
                          SDCondition id_cond,
                          std::vector<ggml_tensor*> ref_latents    = {},
+                         std::vector<ggml_tensor*> ref_clip_feats = {},
                          bool increase_ref_index                  = false,
                          ggml_tensor* denoise_mask                = nullptr,
                          ggml_tensor* vace_context                = nullptr,
@@ -1651,7 +1652,7 @@ class StableDiffusionGGML {
 
         TaylorSeerConfig tcfg;
         tcfg.enabled = (cache_params->mode == SD_CACHE_TAYLORSEER ||
-                       cache_params->mode == SD_CACHE_CACHE_DIT);
+                        cache_params->mode == SD_CACHE_CACHE_DIT);
         tcfg.n_derivatives       = cache_params->taylorseer_n_derivatives;
         tcfg.skip_interval_steps = cache_params->taylorseer_skip_interval;
 
@@ -1921,6 +1922,7 @@ class StableDiffusionGGML {
             diffusion_params.timesteps          = timesteps;
             diffusion_params.guidance           = guidance_tensor;
             diffusion_params.ref_latents        = ref_latents;
+            diffusion_params.ref_clip_feats     = ref_clip_feats;
             diffusion_params.increase_ref_index = increase_ref_index;
             diffusion_params.controls           = controls;
             diffusion_params.control_strength   = control_strength;
@@ -1931,10 +1933,11 @@ class StableDiffusionGGML {
             struct ggml_tensor** active_output = &out_cond;
             if (start_merge_step == -1 || step <= start_merge_step) {
                 // cond
-                diffusion_params.context  = cond.c_crossattn;
-                diffusion_params.c_concat = cond.c_concat;
-                diffusion_params.y        = cond.c_vector;
-                active_condition          = &cond;
+                diffusion_params.context        = cond.c_crossattn;
+                diffusion_params.extra_contexts = cond.extra_c_crossattns;
+                diffusion_params.c_concat       = cond.c_concat;
+                diffusion_params.y              = cond.c_vector;
+                active_condition                = &cond;
             } else {
                 diffusion_params.context  = id_cond.c_crossattn;
                 diffusion_params.c_concat = cond.c_concat;
@@ -1965,12 +1968,13 @@ class StableDiffusionGGML {
                     LOG_ERROR("controlnet compute failed");
                 }
             }
-            current_step_skipped      = cache_step_is_skipped();
-            diffusion_params.controls = controls;
-            diffusion_params.context  = uncond.c_crossattn;
-            diffusion_params.c_concat = uncond.c_concat;
-            diffusion_params.y        = uncond.c_vector;
-            bool skip_uncond          = cache_before_condition(&uncond, out_uncond);
+            current_step_skipped            = cache_step_is_skipped();
+            diffusion_params.controls       = controls;
+            diffusion_params.context        = uncond.c_crossattn;
+            diffusion_params.extra_contexts = uncond.extra_c_crossattns;
+            diffusion_params.c_concat       = uncond.c_concat;
+            diffusion_params.y              = uncond.c_vector;
+            bool skip_uncond                = cache_before_condition(&uncond, out_uncond);
             if (!skip_uncond) {
                 if (!work_diffusion_model->compute(n_threads,
                                                    diffusion_params,
@@ -1985,10 +1989,11 @@ class StableDiffusionGGML {
 
             float* img_cond_data = nullptr;
             if (has_img_cond) {
-                diffusion_params.context  = img_cond.c_crossattn;
-                diffusion_params.c_concat = img_cond.c_concat;
-                diffusion_params.y        = img_cond.c_vector;
-                bool skip_img_cond        = cache_before_condition(&img_cond, out_img_cond);
+                diffusion_params.context        = img_cond.c_crossattn;
+                diffusion_params.extra_contexts = img_cond.extra_c_crossattns;
+                diffusion_params.c_concat       = img_cond.c_concat;
+                diffusion_params.y              = img_cond.c_vector;
+                bool skip_img_cond              = cache_before_condition(&img_cond, out_img_cond);
                 if (!skip_img_cond) {
                     if (!work_diffusion_model->compute(n_threads,
                                                        diffusion_params,
@@ -3100,6 +3105,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                                     sd_pm_params_t pm_params,
                                     std::vector<sd_image_t*> ref_images,
                                     std::vector<ggml_tensor*> ref_latents,
+                                    std::vector<ggml_tensor*> ref_clip_feats,
                                     bool increase_ref_index,
                                     ggml_tensor* concat_latent = nullptr,
                                     ggml_tensor* denoise_mask  = nullptr,
@@ -3388,6 +3394,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
                                                      start_merge_step,
                                                      id_cond,
                                                      ref_latents,
+                                                     ref_clip_feats,
                                                      increase_ref_index,
                                                      denoise_mask,
                                                      nullptr,
@@ -3654,6 +3661,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
     }
 
     std::vector<ggml_tensor*> ref_latents;
+    std::vector<ggml_tensor*> ref_clip_feats;
     for (int i = 0; i < ref_images.size(); i++) {
         ggml_tensor* img;
         if (sd_img_gen_params->auto_resize_ref_image) {
@@ -3700,6 +3708,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
 
         ggml_tensor* latent = sd_ctx->sd->encode_first_stage(work_ctx, img);
         ref_latents.push_back(latent);
+
+        auto clip_vision_output = sd_ctx->sd->get_clip_vision_output(work_ctx, *ref_images[i], false, -2);
+        ref_clip_feats.push_back(clip_vision_output);
     }
 
     if (sd_img_gen_params->init_image.data != nullptr || sd_img_gen_params->ref_images_count > 0) {
@@ -3727,6 +3738,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
                                          sd_img_gen_params->pm_params,
                                          ref_images,
                                          ref_latents,
+                                         ref_clip_feats,
                                          sd_img_gen_params->increase_ref_index,
                                          concat_latent,
                                          denoise_mask,
@@ -4095,8 +4107,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                                  high_noise_sample_method,
                                                  high_noise_sigmas,
                                                  -1,
-                                                 {},
-                                                 {},
+                                                 {},  // id_cond
+                                                 {},  // ref_latents
+                                                 {},  // ref_clip_feats
                                                  false,
                                                  denoise_mask,
                                                  vace_context,
@@ -4132,8 +4145,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
                                              sample_method,
                                              sigmas,
                                              -1,
-                                             {},
-                                             {},
+                                             {},  // id_cond
+                                             {},  // ref_latents
+                                             {},  // ref_clip_feats
                                              false,
                                              denoise_mask,
                                              vace_context,
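In generate_image above, every reference image now produces a pair of tensors: a VAE latent from encode_first_stage and a CLIP-vision feature taken from the penultimate layer (the -2 argument to get_clip_vision_output), kept index-aligned in ref_latents and ref_clip_feats. A hedged sketch of that pairing invariant, with stand-in types:

    #include <cassert>
    #include <functional>
    #include <vector>

    struct Tensor;  // stand-in for ggml_tensor
    struct Image;   // stand-in for sd_image_t

    // Encode both representations per reference image so that ref_latents[i]
    // and ref_clip_feats[i] always describe the same image.
    void encode_refs(const std::vector<Image*>& refs,
                     const std::function<Tensor*(Image*)>& encode_latent,
                     const std::function<Tensor*(Image*)>& encode_clip,
                     std::vector<Tensor*>& ref_latents,
                     std::vector<Tensor*>& ref_clip_feats) {
        for (Image* img : refs) {
            ref_latents.push_back(encode_latent(img));
            ref_clip_feats.push_back(encode_clip(img));  // penultimate-layer features
        }
        assert(ref_latents.size() == ref_clip_feats.size());
    }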
diff --git a/z_image.hpp b/z_image.hpp
index af8d57e04..1f34c9fab 100644
--- a/z_image.hpp
+++ b/z_image.hpp
@@ -118,14 +118,37 @@ namespace ZImage {
 
     __STATIC_INLINE__ struct ggml_tensor* modulate(struct ggml_context* ctx,
                                                    struct ggml_tensor* x,
-                                                   struct ggml_tensor* scale) {
+                                                   struct ggml_tensor* scale,
+                                                   bool skip_reshape = false) {
         // x: [N, L, C]
-        // scale: [N, C]
-        scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]);  // [N, 1, C]
-        x     = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
+        // scale: [N, C] or [N, L, C]
+        if (!skip_reshape) {
+            scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]);  // [N, 1, C]
+        }
+        x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
         return x;
     }
 
+    __STATIC_INLINE__ struct ggml_tensor* select_per_token(struct ggml_context* ctx,
+                                                           struct ggml_tensor* index,
+                                                           struct ggml_tensor* mod_0,
+                                                           struct ggml_tensor* mod_1) {
+        // index: [N, L]
+        // mod_0/mod_1: [N, C]
+        // return: [N, L, C]
+        // mod_result = torch.where(index == 0, mod_0, mod_1)
+        // mod_result = (1 - index)*mod_0 + index*mod_1
+        index = ggml_reshape_3d(ctx, index, 1, index->ne[0], index->ne[1]);
+        index = ggml_repeat_4d(ctx, index, mod_0->ne[0], index->ne[1], index->ne[2], 1);  // [N, L, C]
+        mod_0 = ggml_reshape_3d(ctx, mod_0, mod_0->ne[0], 1, mod_0->ne[1]);  // [N, 1, C]
+        mod_1 = ggml_reshape_3d(ctx, mod_1, mod_1->ne[0], 1, mod_1->ne[1]);  // [N, 1, C]
+
+        mod_0 = ggml_sub(ctx, ggml_repeat(ctx, mod_0, index), ggml_mul(ctx, index, mod_0));  // [N, L, C]
+        mod_1 = ggml_mul(ctx, index, mod_1);  // [N, L, C]
+
+        auto mod_result = ggml_add(ctx, mod_0, mod_1);
+        return mod_result;
+    }
+
     struct JointTransformerBlock : public GGMLBlock {
     protected:
         bool modulation;
@@ -157,7 +180,10 @@ namespace ZImage {
                                     struct ggml_tensor* x,
                                     struct ggml_tensor* pe,
                                     struct ggml_tensor* mask        = nullptr,
-                                    struct ggml_tensor* adaln_input = nullptr) {
+                                    struct ggml_tensor* adaln_input = nullptr,
+                                    struct ggml_tensor* noise_mask  = nullptr,
+                                    struct ggml_tensor* adaln_noisy = nullptr,
+                                    struct ggml_tensor* adaln_clean = nullptr) {
             auto attention       = std::dynamic_pointer_cast<Attention>(blocks["attention"]);
             auto feed_forward    = std::dynamic_pointer_cast<FeedForward>(blocks["feed_forward"]);
             auto attention_norm1 = std::dynamic_pointer_cast<RMSNorm>(blocks["attention_norm1"]);
@@ -166,32 +192,55 @@ namespace ZImage {
             auto ffn_norm2       = std::dynamic_pointer_cast<RMSNorm>(blocks["ffn_norm2"]);
 
             if (modulation) {
-                GGML_ASSERT(adaln_input != nullptr);
                 auto adaLN_modulation_0 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.0"]);
 
-                auto m         = adaLN_modulation_0->forward(ctx, adaln_input);  // [N, 4 * hidden_size]
-                auto mods      = ggml_ext_chunk(ctx->ggml_ctx, m, 4, 0);
-                auto scale_msa = mods[0];
-                auto gate_msa  = mods[1];
-                auto scale_mlp = mods[2];
-                auto gate_mlp  = mods[3];
+                struct ggml_tensor* scale_msa = nullptr;
+                struct ggml_tensor* gate_msa  = nullptr;
+                struct ggml_tensor* scale_mlp = nullptr;
+                struct ggml_tensor* gate_mlp  = nullptr;
+                bool skip_reshape             = false;
+
+                if (noise_mask != nullptr) {
+                    GGML_ASSERT(adaln_noisy != nullptr);
+                    GGML_ASSERT(adaln_clean != nullptr);
+
+                    auto mod_noisy = adaLN_modulation_0->forward(ctx, adaln_noisy);  // [N, 4 * hidden_size]
+                    auto mod_clean = adaLN_modulation_0->forward(ctx, adaln_clean);  // [N, 4 * hidden_size]
+
+                    auto mod_noisy_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_noisy, 4, 0);
+                    auto mod_clean_vec = ggml_ext_chunk(ctx->ggml_ctx, mod_clean, 4, 0);
+
+                    scale_msa = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[0], mod_noisy_vec[0]);
+                    gate_msa  = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[1], mod_noisy_vec[1]);
+                    scale_mlp = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[2], mod_noisy_vec[2]);
+                    gate_mlp  = select_per_token(ctx->ggml_ctx, noise_mask, mod_clean_vec[3], mod_noisy_vec[3]);
+
+                    skip_reshape = true;
+                } else {
+                    GGML_ASSERT(adaln_input != nullptr);
+
+                    auto mod     = adaLN_modulation_0->forward(ctx, adaln_input);  // [N, 4 * hidden_size]
+                    auto mod_vec = ggml_ext_chunk(ctx->ggml_ctx, mod, 4, 0);
+                    scale_msa    = mod_vec[0];
+                    gate_msa     = mod_vec[1];
+                    scale_mlp    = mod_vec[2];
+                    gate_mlp     = mod_vec[3];
+                }
 
                 auto residual = x;
-                x             = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa);
+                x             = modulate(ctx->ggml_ctx, attention_norm1->forward(ctx, x), scale_msa, skip_reshape);
                 x             = attention->forward(ctx, x, pe, mask);
                 x             = attention_norm2->forward(ctx, x);
                 x             = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_msa));
                 x             = ggml_add(ctx->ggml_ctx, x, residual);
 
                 residual = x;
-                x        = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp);
+                x        = modulate(ctx->ggml_ctx, ffn_norm1->forward(ctx, x), scale_mlp, skip_reshape);
                 x        = feed_forward->forward(ctx, x);
                 x        = ffn_norm2->forward(ctx, x);
                 x        = ggml_mul(ctx->ggml_ctx, x, ggml_tanh(ctx->ggml_ctx, gate_mlp));
                 x        = ggml_add(ctx->ggml_ctx, x, residual);
             } else {
-                GGML_ASSERT(adaln_input == nullptr);
-
                 auto residual = x;
                 x             = attention_norm1->forward(ctx, x);
                 x             = attention->forward(ctx, x, pe, mask);
@@ -221,7 +270,10 @@ namespace ZImage {
 
         struct ggml_tensor* forward(GGMLRunnerContext* ctx,
                                     struct ggml_tensor* x,
-                                    struct ggml_tensor* c) {
+                                    struct ggml_tensor* c,
+                                    struct ggml_tensor* noise_mask = nullptr,
+                                    struct ggml_tensor* c_noisy    = nullptr,
+                                    struct ggml_tensor* c_clean    = nullptr) {
             // x: [N, n_token, hidden_size]
             // c: [N, hidden_size]
             // return: [N, n_token, patch_size * patch_size * out_channels]
@@ -229,10 +281,28 @@ namespace ZImage {
             auto linear             = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
             auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);
 
-            auto scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c));  // [N, hidden_size]
-            x          = norm_final->forward(ctx, x);
-            x          = modulate(ctx->ggml_ctx, x, scale);
-            x          = linear->forward(ctx, x);
+            struct ggml_tensor* scale = nullptr;
+            bool skip_reshape         = false;
+
+            if (noise_mask != nullptr) {
+                GGML_ASSERT(c_noisy != nullptr);
+                GGML_ASSERT(c_clean != nullptr);
+
+                auto scale_noisy = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c_noisy));  // [N, hidden_size]
+                auto scale_clean = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c_clean));  // [N, hidden_size]
+
+                scale = select_per_token(ctx->ggml_ctx, noise_mask, scale_clean, scale_noisy);
+
+                skip_reshape = true;
+            } else {
+                GGML_ASSERT(c != nullptr);
+
+                scale = adaLN_modulation_1->forward(ctx, ggml_silu(ctx->ggml_ctx, c));  // [N, hidden_size]
+            }
+
+            x = norm_final->forward(ctx, x);
+            x = modulate(ctx->ggml_ctx, x, scale, skip_reshape);
+            x = linear->forward(ctx, x);
 
             return x;
         }
@@ -253,6 +323,7 @@ namespace ZImage {
         float norm_eps       = 1e-5f;
         bool qk_norm         = true;
         int64_t cap_feat_dim = 2560;
+        int64_t siglip_feat_dim = 0;
         float theta          = 256.f;
         std::vector<int> axes_dim = {32, 48, 48};
         int64_t axes_dim_sum      = 128;
@@ -265,6 +336,10 @@ namespace ZImage {
         void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") override {
             params["cap_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
             params["x_pad_token"]   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
+
+            if (z_image_params.siglip_feat_dim > 0) {
+                params["siglip_pad_token"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, z_image_params.hidden_size);
+            }
         }
 
     public:
@@ -306,6 +381,26 @@ namespace ZImage {
                 blocks["context_refiner." + std::to_string(i)] = block;
             }
 
+            if (z_image_params.siglip_feat_dim > 0) {
+                blocks["siglip_embedder.0"] = std::make_shared<RMSNorm>(z_image_params.siglip_feat_dim, z_image_params.norm_eps);
+                blocks["siglip_embedder.1"] = std::make_shared<Linear>(z_image_params.siglip_feat_dim, z_image_params.hidden_size);
+
+                for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
+                    auto block = std::make_shared<JointTransformerBlock>(2000 + i,
+                                                                         z_image_params.hidden_size,
+                                                                         z_image_params.head_dim,
+                                                                         z_image_params.num_heads,
+                                                                         z_image_params.num_kv_heads,
+                                                                         z_image_params.multiple_of,
+                                                                         z_image_params.ffn_dim_multiplier,
+                                                                         z_image_params.norm_eps,
+                                                                         z_image_params.qk_norm,
+                                                                         false);
+
+                    blocks["siglip_refiner." + std::to_string(i)] = block;
+                }
+            }
+
             for (int i = 0; i < z_image_params.num_layers; i++) {
                 auto block = std::make_shared<JointTransformerBlock>(i,
                                                                      z_image_params.hidden_size,
@@ -387,11 +482,32 @@ namespace ZImage {
             return x;
         }
 
-        struct ggml_tensor* forward_core(GGMLRunnerContext* ctx,
-                                         struct ggml_tensor* x,
-                                         struct ggml_tensor* timestep,
-                                         struct ggml_tensor* context,
-                                         struct ggml_tensor* pe) {
+        std::pair<ggml_tensor*, ggml_tensor*> _pad_and_gen_noise_mask(GGMLRunnerContext* ctx,
+                                                                      ggml_tensor* x,
+                                                                      ggml_tensor* pad_token,
+                                                                      int N,
+                                                                      float noise_mask_value = 1.f) {
+            int64_t n_pad_token = Rope::bound_mod(x->ne[1], SEQ_MULTI_OF);
+            if (n_pad_token > 0) {
+                auto pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, pad_token, pad_token->ne[0], n_pad_token, N, 1);
+                x               = ggml_concat(ctx->ggml_ctx, x, pad_tokens, 1);  // [N, n_token + n_pad_token, hidden_size]
+            }
+
+            ggml_tensor* noise_mask = nullptr;
+            if (noise_mask_value == 0.f) {
+                noise_mask = ggml_ext_zeros(ctx->ggml_ctx, x->ne[1], 1, 1, 1);
+            } else if (noise_mask_value == 1.f) {
+                noise_mask = ggml_ext_ones(ctx->ggml_ctx, x->ne[1], 1, 1, 1);
+            }
+            return {x, noise_mask};
+        }
+
+        struct ggml_tensor* forward_omni(GGMLRunnerContext* ctx,
+                                         ggml_tensor* x,
+                                         ggml_tensor* timestep,
+                                         std::vector<ggml_tensor*> contexts,
+                                         ggml_tensor* pe,
+                                         std::vector<ggml_tensor*> ref_latents,
+                                         std::vector<ggml_tensor*> siglip_feats) {
             auto x_embedder     = std::dynamic_pointer_cast<Linear>(blocks["x_embedder"]);
             auto t_embedder     = std::dynamic_pointer_cast<TimestepEmbedder>(blocks["t_embedder"]);
             auto cap_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["cap_embedder.0"]);
@@ -402,31 +518,145 @@ namespace ZImage {
             auto txt_pad_token = params["cap_pad_token"];
             auto img_pad_token = params["x_pad_token"];
 
-            int64_t N           = x->ne[2];
-            int64_t n_img_token = x->ne[1];
-            int64_t n_txt_token = context->ne[1];
+            bool omni_mode = ref_latents.size() > 0;
+
+            int64_t N = x->ne[2];
+
+            // noise mask of img: 0 for condition images (clean), 1 for target image (noisy)
+            // noise mask of txt/sig: same as the corresponding img. If there is no corresponding img, set to 1
+
+            ggml_tensor* txt            = nullptr;
+            ggml_tensor* txt_noise_mask = nullptr;
+            for (int i = 0; i < contexts.size(); i++) {
+                auto curr_txt_raw = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, contexts[i]));  // [N, n_txt_token, hidden_size]
+
+                float noise_mask_value = -1.f;  // empty noise mask
+                if (omni_mode) {
+                    noise_mask_value = (i < ref_latents.size() ? 0.f : 1.f);
+                }
+
+                auto [curr_txt, curr_txt_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_txt_raw, txt_pad_token, N, noise_mask_value);
+                if (txt == nullptr) {
+                    txt = curr_txt;
+                } else {
+                    txt = ggml_concat(ctx->ggml_ctx, txt, curr_txt, 1);
+                }
+
+                if (omni_mode) {
+                    if (txt_noise_mask == nullptr) {
+                        txt_noise_mask = curr_txt_noise_mask;
+                    } else {
+                        txt_noise_mask = ggml_concat(ctx->ggml_ctx, txt_noise_mask, curr_txt_noise_mask, 0);
+                    }
+                }
+            }
+
+            ggml_tensor* img            = nullptr;
+            ggml_tensor* img_noise_mask = nullptr;
+            for (ggml_tensor* ref : ref_latents) {
+                auto curr_img_raw = x_embedder->forward(ctx, ref);  // [N, n_img_token, hidden_size]
+
+                float noise_mask_value = -1.f;  // empty noise mask
+                if (omni_mode) {
+                    noise_mask_value = 0.f;  // condition image: clean
+                }
+
+                auto [curr_img, curr_img_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_img_raw, img_pad_token, N, noise_mask_value);
+                if (img == nullptr) {
+                    img = curr_img;
+                } else {
+                    img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1);
+                }
+
+                if (omni_mode) {
+                    if (img_noise_mask == nullptr) {
+                        img_noise_mask = curr_img_noise_mask;
+                    } else {
+                        img_noise_mask = ggml_concat(ctx->ggml_ctx, img_noise_mask, curr_img_noise_mask, 0);
+                    }
+                }
+            }
+
+            int64_t final_img_offset  = (img ? img->ne[1] : 0);
+            int64_t final_img_pad_len = 0;
+
+            {
+                auto curr_img_raw = x_embedder->forward(ctx, x);  // [N, n_img_token, hidden_size]
+
+                float noise_mask_value = -1.f;  // empty noise mask
+                if (omni_mode) {
+                    noise_mask_value = 1.f;  // target image: noisy
+                }
+
+                auto [curr_img, curr_img_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_img_raw, img_pad_token, N, noise_mask_value);
+                if (img == nullptr) {
+                    img = curr_img;
+                } else {
+                    img = ggml_concat(ctx->ggml_ctx, img, curr_img, 1);
+                }
+
+                if (omni_mode) {
+                    if (img_noise_mask == nullptr) {
+                        img_noise_mask = curr_img_noise_mask;
+                    } else {
+                        img_noise_mask = ggml_concat(ctx->ggml_ctx, img_noise_mask, curr_img_noise_mask, 0);
+                    }
+                }
+
+                final_img_pad_len = Rope::bound_mod(curr_img_raw->ne[1], SEQ_MULTI_OF);
+            }
+
+            ggml_tensor* sig            = nullptr;
+            ggml_tensor* sig_noise_mask = nullptr;
+            for (int i = 0; i < siglip_feats.size(); i++) {
+                auto sig_pad_token     = params["siglip_pad_token"];
+                auto siglip_embedder_0 = std::dynamic_pointer_cast<RMSNorm>(blocks["siglip_embedder.0"]);
+                auto siglip_embedder_1 = std::dynamic_pointer_cast<Linear>(blocks["siglip_embedder.1"]);
+
+                auto curr_sig_raw = siglip_embedder_1->forward(ctx, siglip_embedder_0->forward(ctx, siglip_feats[i]));  // [N, n_sig_token, hidden_size]
 
-            auto t_emb = t_embedder->forward(ctx, timestep);
+                float noise_mask_value = -1.f;  // empty noise mask
+                if (omni_mode) {
+                    noise_mask_value = (i < ref_latents.size() ? 0.f : 1.f);
+                }
 
-            auto txt = cap_embedder_1->forward(ctx, cap_embedder_0->forward(ctx, context));  // [N, n_txt_token, hidden_size]
-            auto img = x_embedder->forward(ctx, x);                                          // [N, n_img_token, hidden_size]
+                auto [curr_sig, curr_sig_noise_mask] = _pad_and_gen_noise_mask(ctx, curr_sig_raw, sig_pad_token, N, noise_mask_value);
+                if (sig == nullptr) {
+                    sig = curr_sig;
+                } else {
+                    sig = ggml_concat(ctx->ggml_ctx, sig, curr_sig, 1);
+                }
 
-            int64_t n_txt_pad_token = Rope::bound_mod(n_txt_token, SEQ_MULTI_OF);
-            if (n_txt_pad_token > 0) {
-                auto txt_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, txt_pad_token, txt_pad_token->ne[0], n_txt_pad_token, N, 1);
-                txt                 = ggml_concat(ctx->ggml_ctx, txt, txt_pad_tokens, 1);  // [N, n_txt_token + n_txt_pad_token, hidden_size]
+                if (omni_mode) {
+                    if (sig_noise_mask == nullptr) {
+                        sig_noise_mask = curr_sig_noise_mask;
+                    } else {
+                        sig_noise_mask = ggml_concat(ctx->ggml_ctx, sig_noise_mask, curr_sig_noise_mask, 0);
+                    }
+                }
             }
 
-            int64_t n_img_pad_token = Rope::bound_mod(n_img_token, SEQ_MULTI_OF);
-            if (n_img_pad_token > 0) {
-                auto img_pad_tokens = ggml_repeat_4d(ctx->ggml_ctx, img_pad_token, img_pad_token->ne[0], n_img_pad_token, N, 1);
-                img                 = ggml_concat(ctx->ggml_ctx, img, img_pad_tokens, 1);  // [N, n_img_token + n_img_pad_token, hidden_size]
+            ggml_tensor* t_emb   = nullptr;
+            ggml_tensor* t_noisy = nullptr;
+            ggml_tensor* t_clean = nullptr;
+            if (omni_mode) {
+                t_noisy = t_embedder->forward(ctx, timestep);
+                t_clean = t_embedder->forward(ctx,
+                                              ggml_scale(ctx->ggml_ctx,
+                                                         ggml_ext_ones(ctx->ggml_ctx, timestep->ne[0], timestep->ne[1], timestep->ne[2], timestep->ne[3]),
+                                                         1000.f));
+            } else {
+                t_emb = t_embedder->forward(ctx, timestep);
             }
 
-            GGML_ASSERT(txt->ne[1] + img->ne[1] == pe->ne[3]);
+            if (sig) {
+                GGML_ASSERT(txt->ne[1] + img->ne[1] + sig->ne[1] == pe->ne[3]);
+            } else {
+                GGML_ASSERT(txt->ne[1] + img->ne[1] == pe->ne[3]);
+            }
 
             auto txt_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, 0, txt->ne[1]);
-            auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], pe->ne[3]);
+            auto img_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1], txt->ne[1] + img->ne[1]);
 
             for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
                 auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["context_refiner." + std::to_string(i)]);
@@ -437,30 +667,50 @@ namespace ZImage {
             for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
                 auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["noise_refiner." + std::to_string(i)]);
 
-                img = block->forward(ctx, img, img_pe, nullptr, t_emb);
+                img = block->forward(ctx, img, img_pe, nullptr, t_emb, img_noise_mask, t_noisy, t_clean);
             }
 
-            auto txt_img = ggml_concat(ctx->ggml_ctx, txt, img, 1);  // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, hidden_size]
+            auto unified = ggml_concat(ctx->ggml_ctx, txt, img, 1);  // [N, n_txt_token + n_img_token, hidden_size]
+
+            ggml_tensor* noise_mask = nullptr;
+            if (omni_mode) {
+                noise_mask = ggml_concat(ctx->ggml_ctx, txt_noise_mask, img_noise_mask, 0);  // [N, n_txt_token + n_img_token]
+            }
+
+            ggml_tensor* sig_pe = nullptr;
+            if (sig) {
+                sig_pe = ggml_ext_slice(ctx->ggml_ctx, pe, 3, txt->ne[1] + img->ne[1], pe->ne[3]);
+
+                for (int i = 0; i < z_image_params.num_refiner_layers; i++) {
+                    auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["siglip_refiner." + std::to_string(i)]);
+
+                    sig = block->forward(ctx, sig, sig_pe, nullptr, nullptr);
+                }
+
+                unified    = ggml_concat(ctx->ggml_ctx, unified, sig, 1);              // [N, n_txt_token + n_img_token + n_sig_token, hidden_size]
+                noise_mask = ggml_concat(ctx->ggml_ctx, noise_mask, sig_noise_mask, 0);  // [N, n_txt_token + n_img_token + n_sig_token]
+            }
 
             for (int i = 0; i < z_image_params.num_layers; i++) {
                 auto block = std::dynamic_pointer_cast<JointTransformerBlock>(blocks["layers." + std::to_string(i)]);
 
-                txt_img = block->forward(ctx, txt_img, pe, nullptr, t_emb);
+                unified = block->forward(ctx, unified, pe, nullptr, t_emb, noise_mask, t_noisy, t_clean);
             }
 
-            txt_img = final_layer->forward(ctx, txt_img, t_emb);  // [N, n_txt_token + n_txt_pad_token + n_img_token + n_img_pad_token, ph*pw*C]
+            unified = final_layer->forward(ctx, unified, t_emb, noise_mask, t_noisy, t_clean);  // [N, n_txt_token + n_img_token + n_sig_token, ph*pw*C]
 
-            img = ggml_ext_slice(ctx->ggml_ctx, txt_img, 1, n_txt_token + n_txt_pad_token, n_txt_token + n_txt_pad_token + n_img_token);  // [N, n_img_token, ph*pw*C]
+            img = ggml_ext_slice(ctx->ggml_ctx, unified, 1, txt->ne[1] + final_img_offset, txt->ne[1] + img->ne[1] - final_img_pad_len);  // [N, n_final_img_token, ph*pw*C]
 
             return img;
         }
 
         struct ggml_tensor* forward(GGMLRunnerContext* ctx,
-                                    struct ggml_tensor* x,
-                                    struct ggml_tensor* timestep,
-                                    struct ggml_tensor* context,
-                                    struct ggml_tensor* pe,
-                                    std::vector<ggml_tensor*> ref_latents = {}) {
+                                    ggml_tensor* x,
+                                    ggml_tensor* timestep,
+                                    std::vector<ggml_tensor*> contexts,
+                                    ggml_tensor* pe,
+                                    std::vector<ggml_tensor*> ref_latents  = {},
+                                    std::vector<ggml_tensor*> siglip_feats = {}) {
             // Forward pass of DiT.
             // x: [N, C, H, W]
             // timestep: [N,]
@@ -473,23 +723,20 @@ namespace ZImage {
             int64_t C = x->ne[2];
             int64_t N = x->ne[3];
 
-            auto img             = process_img(ctx, x);
-            uint64_t n_img_token = img->ne[1];
-
-            if (ref_latents.size() > 0) {
-                for (ggml_tensor* ref : ref_latents) {
-                    ref = process_img(ctx, ref);
-                    img = ggml_concat(ctx->ggml_ctx, img, ref, 1);
-                }
-            }
+            auto img = process_img(ctx, x);
 
             int64_t h_len = ((H + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
             int64_t w_len = ((W + (z_image_params.patch_size / 2)) / z_image_params.patch_size);
 
-            auto out = forward_core(ctx, img, timestep, context, pe);
+            for (int i = 0; i < ref_latents.size(); i++) {
+                ref_latents[i] = process_img(ctx, ref_latents[i]);
+            }
+
+            auto out = forward_omni(ctx, img, timestep, contexts, pe, ref_latents, siglip_feats);  // [N, n_img_token, ph*pw*C]
+
+            // auto out = forward_basic(ctx, img, timestep, contexts[0], pe);  // [N, n_img_token, ph*pw*C]
 
-            out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, n_img_token);  // [N, n_img_token, ph*pw*C]
-            out = unpatchify(ctx->ggml_ctx, out, h_len, w_len);           // [N, C, H + pad_h, W + pad_w]
+            out = unpatchify(ctx->ggml_ctx, out, h_len, w_len);  // [N, C, H + pad_h, W + pad_w]
 
             // slice
             out = ggml_ext_slice(ctx->ggml_ctx, out, 1, 0, H);  // [N, C, H, W + pad_w]
@@ -527,34 +774,37 @@ namespace ZImage {
             z_image.get_param_tensors(tensors, prefix);
         }
 
-        struct ggml_cgraph* build_graph(struct ggml_tensor* x,
-                                        struct ggml_tensor* timesteps,
-                                        struct ggml_tensor* context,
-                                        std::vector<ggml_tensor*> ref_latents = {},
-                                        bool increase_ref_index               = false) {
+        struct ggml_cgraph* build_graph(ggml_tensor* x,
+                                        ggml_tensor* timesteps,
+                                        std::vector<ggml_tensor*> contexts,
+                                        std::vector<ggml_tensor*> ref_latents  = {},
+                                        std::vector<ggml_tensor*> siglip_feats = {}) {
             GGML_ASSERT(x->ne[3] == 1);
             struct ggml_cgraph* gf = new_graph_custom(Z_IMAGE_GRAPH_SIZE);
 
-            x       = to_backend(x);
-            context = to_backend(context);
+            x = to_backend(x);
+
+            for (int i = 0; i < contexts.size(); i++) {
+                contexts[i] = to_backend(contexts[i]);
+            }
+
             timesteps = to_backend(timesteps);
 
             for (int i = 0; i < ref_latents.size(); i++) {
                 ref_latents[i] = to_backend(ref_latents[i]);
             }
 
-            pe_vec = Rope::gen_z_image_pe(x->ne[1],
-                                          x->ne[0],
+            pe_vec = Rope::gen_z_image_pe(x,
+                                          contexts,
+                                          ref_latents,
+                                          siglip_feats,
                                           z_image_params.patch_size,
-                                          x->ne[3],
-                                          context->ne[1],
                                           SEQ_MULTI_OF,
-                                          ref_latents,
-                                          increase_ref_index,
                                           z_image_params.theta,
+                                          z_image_params.axes_dim,
                                           circular_y_enabled,
                                           circular_x_enabled,
-                                          z_image_params.axes_dim);
+                                          x->ne[3]);
             int pos_len = pe_vec.size() / z_image_params.axes_dim_sum / 2;
             // LOG_DEBUG("pos_len %d", pos_len);
             auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, z_image_params.axes_dim_sum / 2, pos_len);
@@ -567,7 +817,7 @@ namespace ZImage {
             struct ggml_tensor* out = z_image.forward(&runner_ctx,
                                                       x,
                                                       timesteps,
-                                                      context,
+                                                      contexts,
                                                       pe,
                                                       ref_latents);
 
@@ -579,16 +829,16 @@ namespace ZImage {
         bool compute(int n_threads,
                      struct ggml_tensor* x,
                      struct ggml_tensor* timesteps,
-                     struct ggml_tensor* context,
-                     std::vector<ggml_tensor*> ref_latents = {},
-                     bool increase_ref_index               = false,
-                     struct ggml_tensor** output           = nullptr,
-                     struct ggml_context* output_ctx       = nullptr) {
+                     std::vector<ggml_tensor*> contexts,
+                     std::vector<ggml_tensor*> ref_latents  = {},
+                     std::vector<ggml_tensor*> siglip_feats = {},
+                     struct ggml_tensor** output            = nullptr,
+                     struct ggml_context* output_ctx        = nullptr) {
             // x: [N, in_channels, h, w]
             // timesteps: [N, ]
             // context: [N, max_position, hidden_size]
             auto get_graph = [&]() -> struct ggml_cgraph* {
-                return build_graph(x, timesteps, context, ref_latents, increase_ref_index);
+                return build_graph(x, timesteps, contexts, ref_latents, siglip_feats);
             };
 
             return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
@@ -620,7 +870,7 @@ namespace ZImage {
         struct ggml_tensor* out = nullptr;
 
         int t0 = ggml_time_ms();
-        compute(8, x, timesteps, context, {}, false, &out, work_ctx);
+        compute(8, x, timesteps, {context}, {}, {}, &out, work_ctx);
         int t1 = ggml_time_ms();
 
         print_ggml_tensor(out);
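The select_per_token helper above implements torch.where(index == 0, mod_0, mod_1) without a data-dependent branch: for a 0/1 mask m it computes (1 - m) * mod_0 + m * mod_1, broadcast over the channel dimension. A scalar reference version of the same arithmetic:

    #include <cstddef>
    #include <vector>

    // out[t][c] = (1 - mask[t]) * mod_0[c] + mask[t] * mod_1[c]
    // For mask values of exactly 0 or 1 this is a per-token select between the
    // "clean" and "noisy" modulation vectors.
    std::vector<std::vector<float>> select_per_token_ref(const std::vector<float>& mask,
                                                         const std::vector<float>& mod_0,
                                                         const std::vector<float>& mod_1) {
        std::vector<std::vector<float>> out(mask.size(), std::vector<float>(mod_0.size()));
        for (size_t t = 0; t < mask.size(); t++)
            for (size_t c = 0; c < mod_0.size(); c++)
                out[t][c] = (1.f - mask[t]) * mod_0[c] + mask[t] * mod_1[c];
        return out;
    }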