Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 46 additions & 33 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2298,8 +2298,9 @@ get_allele_samples(const tsk_site_t *site, tsk_size_t site_offset,
}

static int
norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights,
tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
norm_hap_weighted(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights,
tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b),
double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *weight_row;
Expand All @@ -2315,8 +2316,9 @@ norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights,
}

static int
norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights,
tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
norm_hap_weighted_ij(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights,
tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b),
double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *weight_row;
Expand All @@ -2341,8 +2343,9 @@ norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights,
}

static int
norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights),
tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params))
norm_total_weighted(tsk_size_t TSK_UNUSED(state_dim),
const double *TSK_UNUSED(hap_weights), tsk_size_t result_dim, tsk_size_t n_a,
tsk_size_t n_b, double *result, void *TSK_UNUSED(params))
{
tsk_size_t k;
double norm = 1 / (double) (n_a * n_b);
Expand Down Expand Up @@ -2411,8 +2414,8 @@ static int
compute_general_normed_two_site_stat_result(const tsk_bitset_t *state,
const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off,
tsk_size_t num_a_alleles, tsk_size_t num_b_alleles, tsk_size_t state_dim,
tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params,
norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result)
tsk_size_t result_dim, general_stat_func_t *f, void *f_params, norm_func_t *norm_f,
bool polarised, two_locus_work_t *restrict work, double *result)
{
int ret = 0;
// Sample sets and b sites are rows, a sites are columns
Expand Down Expand Up @@ -2445,7 +2448,7 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state,
if (ret != 0) {
goto out;
}
ret = norm_f(result_dim, weights, num_a_alleles - is_polarised,
ret = norm_f(state_dim, weights, result_dim, num_a_alleles - is_polarised,
num_b_alleles - is_polarised, norm, f_params);
if (ret != 0) {
goto out;
Expand All @@ -2463,9 +2466,8 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state,
static int
compute_general_two_site_stat_result(const tsk_bitset_t *state,
const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off,
tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f,
sample_count_stat_params_t *f_params, two_locus_work_t *restrict work,
double *result)
tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, void *f_params,
two_locus_work_t *restrict work, double *result)
{
int ret = 0;
tsk_size_t k;
Expand Down Expand Up @@ -2653,9 +2655,8 @@ static int
tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows,
const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites,
tsk_flags_t options, double *result)
void *f_params, norm_func_t *norm_f, tsk_size_t n_rows, const tsk_id_t *row_sites,
tsk_size_t n_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result)
{
int ret = 0;
tsk_bitset_t allele_samples, allele_sample_sets;
Expand Down Expand Up @@ -3089,9 +3090,8 @@ advance_collect_edges(iter_state *s, tsk_id_t index)
static int
compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c,
const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim,
tsk_size_t result_dim, int sign, general_stat_func_t *f,
sample_count_stat_params_t *f_params, two_locus_work_t *restrict work,
double *result)
tsk_size_t result_dim, int sign, general_stat_func_t *f, void *f_params,
two_locus_work_t *restrict work, double *result)
{
int ret = 0;
double a_len, b_len;
Expand Down Expand Up @@ -3141,8 +3141,8 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c,

static int
compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params,
tsk_size_t result_dim, tsk_size_t state_dim, double *result)
iter_state *r_state, general_stat_func_t *f, void *f_params, tsk_size_t result_dim,
tsk_size_t state_dim, double *result)
{
int ret = 0;
tsk_id_t e, c, ec, p, *updated_nodes = NULL;
Expand Down Expand Up @@ -3243,9 +3243,9 @@ static int
tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f),
tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols,
const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result)
void *f_params, norm_func_t *TSK_UNUSED(norm_f), tsk_size_t n_rows,
const double *row_positions, tsk_size_t n_cols, const double *col_positions,
tsk_flags_t TSK_UNUSED(options), double *result)
{
int ret = 0;
int r, c;
Expand Down Expand Up @@ -3385,10 +3385,10 @@ check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s
}

int
tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f,
norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites,
tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites,
const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites,
const double *col_positions, tsk_flags_t options, double *result)
{
Expand All @@ -3398,10 +3398,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
bool stat_site = !!(options & TSK_STAT_SITE);
bool stat_branch = !!(options & TSK_STAT_BRANCH);
tsk_size_t state_dim = num_sample_sets;
sample_count_stat_params_t f_params = { .sample_sets = sample_sets,
.num_sample_sets = num_sample_sets,
.sample_set_sizes = sample_set_sizes,
.set_indexes = set_indexes };

// We do not support two-locus node stats
if (!!(options & TSK_STAT_NODE)) {
Expand Down Expand Up @@ -3441,7 +3437,7 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
goto out;
}
ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets,
sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows,
sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows,
row_sites, out_cols, col_sites, options, result);
} else if (stat_branch) {
ret = check_positions(
Expand All @@ -3455,13 +3451,30 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
goto out;
}
ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets,
sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows,
sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows,
row_positions, out_cols, col_positions, options, result);
}
out:
return ret;
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function now serves as an inner wrapper. The general stat accepts the summary function params so that the CPython code can pass them directly. All of the specialized stats functions call this function.

/* Two-locus count statistic entry point that builds the summary-function
 * parameters itself.
 *
 * The sample set description (sample_sets, num_sample_sets, sample_set_sizes)
 * and set_indexes are packed into a sample_count_stat_params_t, whose address
 * is handed to tsk_treeseq_two_locus_count_general_stat as f_params. Every
 * other argument is forwarded unchanged, and the callee's return value is
 * returned as-is.
 *
 * The parameter struct is a local of this function, so it is only valid for
 * the duration of the forwarded call. */
int
tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
    const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
    tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f,
    norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites,
    const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites,
    const double *col_positions, tsk_flags_t options, double *result)
{
    int ret;
    /* Fields not named here are zero-initialised by the designated initialiser. */
    sample_count_stat_params_t count_params = {
        .sample_sets = sample_sets,
        .num_sample_sets = num_sample_sets,
        .sample_set_sizes = sample_set_sizes,
        .set_indexes = set_indexes,
    };

    ret = tsk_treeseq_two_locus_count_general_stat(self, num_sample_sets,
        sample_set_sizes, sample_sets, result_dim, f, &count_params, norm_f,
        out_rows, row_sites, row_positions, out_cols, col_sites, col_positions,
        options, result);
    return ret;
}

/***********************************
* Allele frequency spectrum
***********************************/
Expand Down
11 changes: 9 additions & 2 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -1036,8 +1036,8 @@ int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const doub
tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows,
const double *windows, tsk_flags_t options, double *result);

typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a,
tsk_size_t n_b, double *result, void *params);
/* Normalisation callback type for the two-locus statistics.
 *
 * The two-site code in trees.c invokes it as
 *   norm_f(state_dim, weights, result_dim, n_a, n_b, norm, f_params)
 * where n_a and n_b are the allele counts of the two sites, each reduced by
 * the caller's is_polarised value, and params is the same pointer that is
 * passed through as the summary function's f_params. A non-zero return value
 * is treated as an error by the caller.
 * NOTE(review): `result` is presumably an output array of result_dim
 * normalisation factors - confirm against the implementations. */
typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights,
tsk_size_t result_dim, tsk_size_t n_a, tsk_size_t n_b, double *result, void *params);

int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
Expand Down Expand Up @@ -1120,6 +1120,13 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self,
const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes,
tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result);

/* General form of the two-locus count statistic.
 *
 * Unlike tsk_treeseq_two_locus_count_stat, which builds a
 * sample_count_stat_params_t internally, this variant takes the summary
 * function parameters from the caller as the opaque pointer f_params and
 * forwards them unchanged to the site or branch implementation.
 *
 * The mode is selected from `options` (TSK_STAT_SITE / TSK_STAT_BRANCH);
 * two-locus node statistics (TSK_STAT_NODE) are not supported. The site path
 * is given row_sites/col_sites and the branch path is given
 * row_positions/col_positions. Returns 0 on success or a non-zero error code. */
int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites,
const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites,
const double *col_positions, tsk_flags_t options, double *result);

typedef int two_locus_count_stat_method(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites,
Expand Down
Loading