Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lib/op-attrs/src/op-attrs/ops/element_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ ParallelTensorDimDegrees get_output_parallel_dim_degrees(
ElementUnaryAttrs const &attrs,
ParallelTensorDimDegrees const &input_degrees) {
ASSERT(input_degrees.sum_degree.value == 1);
ASSERT(input_degrees.discard_copy_degree.value == 1);

return input_degrees;
}
Expand Down
8 changes: 0 additions & 8 deletions lib/op-attrs/test/src/op-attrs/ops/element_unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,5 @@ TEST_SUITE(FF_TEST_SUITE) {
SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p)));
}

SUBCASE("discard copy degree > 1") {
positive_int degree = 2_p;

CHECK_THROWS(get_output_shape(
attrs,
make_input(
SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p)));
}
}
}
19 changes: 11 additions & 8 deletions lib/realm-execution/include/realm-execution/realm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,18 @@ struct RealmContext {
int priority = 0);
///\}

/** \name Data movement */
/** \name Data movement and reduction */
///\{
Realm::Event issue_copy(ParallelTensorShape const &src_shape,
Realm::RegionInstance src_inst,
ParallelTensorShape const &dst_shape,
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on = Realm::Event::NO_EVENT,
int priority = 0);
Realm::Event
issue_copy(ParallelTensorShape const &src_shape,
Realm::RegionInstance src_inst,
ParallelTensorShape const &dst_shape,
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on = Realm::Event::NO_EVENT,
int priority = 0,
std::optional<Realm::ReductionOpID> redop_id = std::nullopt,
bool exclusive = false);
///\}

/** \name Instance management */
Expand Down
109 changes: 109 additions & 0 deletions lib/realm-execution/include/realm-execution/tasks/realm_reduction.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#pragma once
#include "op-attrs/datatype.dtg.h"
#include <realm.h>

namespace FlexFlow {

// Sum reduction for float, shaped for Realm::Runtime::register_reduction:
// provides LHS/RHS typedefs, an additive identity, and apply/fold entry
// points templated on whether the caller has exclusive access.
struct SumReductionFloat {
  using LHS = float;
  using RHS = float;
  // Additive identity Realm uses to initialize reduction buffers.
  static constexpr RHS identity = 0.0f;

  // Combine rhs into the LHS location.
  // EXCLUSIVE=true: caller guarantees no concurrent writers, plain add.
  // EXCLUSIVE=false: must be atomic w.r.t. other reducers of the same cell.
  template <bool EXCLUSIVE>
  static void apply(LHS &lhs, RHS rhs) {
    if constexpr (EXCLUSIVE) {
      lhs += rhs;
    } else {
      // Atomic float add implemented as a compare-and-swap loop on the
      // 32-bit bit pattern (no native atomic float add via __sync builtins).
      // BUGFIX: the previous version also issued
      //   __sync_fetch_and_add((int *)&lhs, *(int *)&rhs);
      // before this loop, which did an *integer* add on the float's bit
      // pattern (corrupting the value) and then applied rhs a second time
      // via the CAS loop. That line is removed; the CAS loop alone is the
      // correct atomic float accumulation.
      union {
        float f;
        int i;
      } old_val, new_val;
      do {
        old_val.f = lhs;
        new_val.f = old_val.f + rhs;
      } while (
          !__sync_bool_compare_and_swap((int *)&lhs, old_val.i, new_val.i));
    }
  }

  // Combine two RHS partial results (same semantics as apply, RHS-to-RHS).
  template <bool EXCLUSIVE>
  static void fold(RHS &rhs1, RHS rhs2) {
    if constexpr (EXCLUSIVE) {
      rhs1 += rhs2;
    } else {
      union {
        float f;
        int i;
      } old_val, new_val;
      do {
        old_val.f = rhs1;
        new_val.f = old_val.f + rhs2;
      } while (
          !__sync_bool_compare_and_swap((int *)&rhs1, old_val.i, new_val.i));
    }
  }
};

// Sum reduction operator for double, shaped for Realm's reduction-op
// registration interface (LHS/RHS typedefs, identity, apply/fold templated
// on exclusivity of access).
struct SumReductionDouble {
  using LHS = double;
  using RHS = double;
  // Additive identity used to initialize reduction buffers.
  static constexpr RHS identity = 0.0;

  // Accumulate rhs into lhs; non-exclusive callers get an atomic add
  // implemented as a CAS loop over the 64-bit bit pattern.
  template <bool EXCLUSIVE>
  static void apply(LHS &lhs, RHS rhs) {
    if (EXCLUSIVE) {
      lhs += rhs;
      return;
    }
    union Bits {
      double d;
      long long i;
    };
    Bits expected, desired;
    do {
      expected.d = lhs;
      desired.d = expected.d + rhs;
    } while (!__sync_bool_compare_and_swap(
        (long long *)&lhs, expected.i, desired.i));
  }

  // Merge two RHS partial sums; same exclusivity contract as apply.
  template <bool EXCLUSIVE>
  static void fold(RHS &rhs1, RHS rhs2) {
    if (EXCLUSIVE) {
      rhs1 += rhs2;
      return;
    }
    union Bits {
      double d;
      long long i;
    };
    Bits expected, desired;
    do {
      expected.d = rhs1;
      desired.d = expected.d + rhs2;
    } while (!__sync_bool_compare_and_swap(
        (long long *)&rhs1, expected.i, desired.i));
  }
};

// Reduction op IDs — must not conflict with other registered redops
// NOTE(review): Realm reduction op IDs must be unique process-wide and
// nonzero; this assumes no other code path registers redops with IDs 1
// or 2 — TODO confirm against the rest of the registration code.
enum SumReductionOpIDs {
  REDOP_SUM_FLOAT = 1,
  REDOP_SUM_DOUBLE = 2,
};

/// Map a tensor element datatype to the Realm reduction op ID of its
/// registered sum reduction. Aborts (PANIC) for datatypes without a
/// registered sum redop.
inline Realm::ReductionOpID get_sum_reduction_op_id(DataType dtype) {
  if (dtype == DataType::FLOAT) {
    return REDOP_SUM_FLOAT;
  }
  if (dtype == DataType::DOUBLE) {
    return REDOP_SUM_DOUBLE;
  }
  PANIC("no sum reduction registered for datatype {}", dtype);
}

} // namespace FlexFlow
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
std::unordered_map<DynamicNodeInvocation,
DeviceSpecificPtr<PerDeviceOpState> *>
device_state_map;
std::vector<Realm::Event> completion_events;
for (DynamicNodeInvocation const &invocation : dg.invocations) {
Realm::Processor target_proc = ctx.map_device_coord_to_processor(
assert_unwrap(invocation.node_attrs.device_coord));
Expand All @@ -56,14 +57,17 @@ PerDeviceOpStateBacking perform_distributed_per_device_op_state_initialization(
precondition);

if (completion_event.has_value()) {
completion_events.push_back(completion_event.value());
device_state_map.insert(std::pair{invocation, device_state_ptr});
} else {
// Task doesn't require initialization, clean up and don't store result
delete device_state_ptr;
}
}

ctx.get_outstanding_events().wait();
// wait for all init tasks — direct write to *result_ptr happens
// before each init task event fires so result is ready after this
Realm::Event::merge_events(completion_events).wait();

auto deref = [](DeviceSpecificPtr<PerDeviceOpState> *const &p) { return *p; };
std::unordered_map<DynamicNodeInvocation, DeviceSpecificPtr<PerDeviceOpState>>
Expand Down
54 changes: 54 additions & 0 deletions lib/realm-execution/src/realm-execution/pcg_instance.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "realm-execution/instance_allocation.h"
#include "realm-execution/realm_context.h"
#include "realm-execution/tasks/impl/op_task.h"
#include "realm-execution/tasks/realm_reduction.h"
#include "realm-execution/tensor_instance_backing.h"
#include "task-spec/dynamic_graph/copy_insertion.h"
#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h"
Expand Down Expand Up @@ -215,18 +216,71 @@ static Realm::Event spawn_dynamic_node_invocation(
precondition);
};

// issue_replicate_bwd lambda
auto issue_replicate_bwd = [&]() {
std::optional<DynamicValueAttrs> output_grad_opt;
for (auto const &[slot, value] : invocation.inputs) {
if (slot.slot_tensor_role == DynamicTensorRole{FwbTensorType::GRADIENT}) {
output_grad_opt = value;
}
}
DynamicValueAttrs output_grad = assert_unwrap(output_grad_opt);
DynamicValueAttrs input_grad = get_only(invocation.outputs).second;
Realm::RegionInstance dst_inst =
tensor_instance_backing.backing.at(input_grad).first;

Realm::ReductionOpID redop_id = get_sum_reduction_op_id(
assert_unwrap(output_grad.parallel_tensor_shape).data_type);

// chain reductions sequentially to avoid write races on dst
Realm::Event e = precondition;
for (auto const &[p, m] : assert_unwrap(output_grad.mapping)) {
DynamicValueAttrs replica_key = output_grad;
replica_key.mapping =
bidict<ParallelTensorSpaceCoordinate, MachineSpaceCoordinate>{{p, m}};
replica_key.shard_coord = p;

Realm::RegionInstance src_inst =
tensor_instance_backing.backing.at(replica_key).first;

e = ctx.issue_copy(assert_unwrap(output_grad.parallel_tensor_shape),
src_inst,
assert_unwrap(input_grad.parallel_tensor_shape),
dst_inst,
Realm::ProfilingRequestSet{},
e,
0,
redop_id,
false);
}
return e;
};

TrainingOperationAttrs op_attrs =
assert_unwrap(invocation.node_attrs.op_attrs);
return op_attrs.visit<Realm::Event>(overload{
[&](PCGOperatorAttrs const &pcg_op_attrs) {
return pcg_op_attrs.visit<Realm::Event>(overload{
[&](InputAttrs const &) { return Realm::Event::NO_EVENT; },
[&](WeightAttrs const &) { return Realm::Event::NO_EVENT; },
[&](ReplicateAttrs const &) {
// this should never be reached since replicate
// goes through TrainingOperationAttrs::ReplicateAttrs
PANIC("unexpected replicate in PCGOperatorAttrs path");
return Realm::Event::NO_EVENT;
},
[&](auto const &) { return spawn_task(); },
});
},
[&](LossAttrs const &) { return spawn_task(); },
[&](CopyAttrs const &) { return issue_copy(); },
[&](ReplicateAttrs const &) {
if (invocation.node_attrs.task_type.has_value() &&
invocation.node_attrs.task_type.value() == DynamicTaskType::BWD) {
return issue_replicate_bwd();
}
return issue_copy();
},
});
}

Expand Down
9 changes: 8 additions & 1 deletion lib/realm-execution/src/realm-execution/realm_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ Realm::Event
Realm::RegionInstance dst_inst,
Realm::ProfilingRequestSet const &requests,
Realm::Event wait_on,
int priority) {
int priority,
std::optional<Realm::ReductionOpID> redop_id,
bool exclusive) {
TensorShape src_piece_shape = get_piece_shape(src_shape);
TensorShape dst_piece_shape = get_piece_shape(dst_shape);
ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match
Expand All @@ -183,6 +185,11 @@ Realm::Event
size_of_datatype(src_piece_shape.data_type).int_from_positive_int()),
/*subfield_offset=*/0);

// set reduction op on dst field if provided
if (redop_id.has_value()) {
dst_field.set_redop(redop_id.value(), /*is_fold=*/false, exclusive);
}

Realm::Event result;
switch (src_piece_shape.dims.ff_ordered.num_dims()) {
#if REALM_MAX_DIM >= 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,17 @@ void per_device_op_state_init_task_body(void const *args,
result_state, ctx.get_current_device_idx())};
DeviceSpecificPtr<PerDeviceOpState> result_device_specific{
ctx.get_current_device_idx(), result_state_ptr};
spawn_per_device_op_state_init_return_task(ctx,
task_args.origin_proc,
result_device_specific,
task_args.origin_result_ptr,
Realm::Event::NO_EVENT);

// replace spawn_per_device_op_state_init_return_task with:
// NOTE: SM/TODO: direct write assumes single-node shared address space
// For multi-node, replace with UserEvent trigger pattern
*task_args.origin_result_ptr = result_device_specific;

// spawn_per_device_op_state_init_return_task(ctx,
// task_args.origin_proc,
// result_device_specific,
// task_args.origin_result_ptr,
// Realm::Event::NO_EVENT);
}

std::optional<Realm::Event> spawn_per_device_op_state_init_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "realm-execution/tasks/impl/op_task.h"
#include "realm-execution/tasks/impl/per_device_op_state_init_return_task.h"
#include "realm-execution/tasks/impl/per_device_op_state_init_task.h"
#include "realm-execution/tasks/realm_reduction.h"
#include "realm-execution/tasks/task_id_t.h"
#include "utils/exception.h"

Expand All @@ -30,9 +31,18 @@ Realm::Event register_task(Realm::Processor::Kind target_kind,
Realm::ProfilingRequestSet());
}

static void register_reductions() {
// register sum reduction ops
Realm::Runtime rt = Realm::Runtime::get_runtime();
rt.register_reduction<SumReductionFloat>(REDOP_SUM_FLOAT);
rt.register_reduction<SumReductionDouble>(REDOP_SUM_DOUBLE);
// register_reduction is synchronous — no event returned
}

Realm::Event register_all_tasks() {
std::vector<Realm::Event> pending_registrations;

register_reductions();
std::vector<task_id_t> init_task_ids = {
// Init tasks
task_id_t::BATCHNORM_INIT_TASK_ID,
Expand Down
Loading