diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 1aa06e5b03..cccf56a8be 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2298,8 +2298,9 @@ get_allele_samples(const tsk_site_t *site, tsk_size_t site_offset, } static int -norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2315,8 +2316,9 @@ norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, } static int -norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted_ij(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2341,8 +2343,9 @@ norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, } static int -norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights), - tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) +norm_total_weighted(tsk_size_t TSK_UNUSED(state_dim), + const double *TSK_UNUSED(hap_weights), tsk_size_t result_dim, tsk_size_t n_a, + tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) { tsk_size_t k; double norm = 1 / (double) (n_a * n_b); @@ -2411,8 +2414,8 @@ static int compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, tsk_size_t num_a_alleles, tsk_size_t num_b_alleles, tsk_size_t 
state_dim, - tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, - norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result) + tsk_size_t result_dim, general_stat_func_t *f, void *f_params, norm_func_t *norm_f, + bool polarised, two_locus_work_t *restrict work, double *result) { int ret = 0; // Sample sets and b sites are rows, a sites are columns @@ -2445,7 +2448,7 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, if (ret != 0) { goto out; } - ret = norm_f(result_dim, weights, num_a_alleles - is_polarised, + ret = norm_f(state_dim, weights, result_dim, num_a_alleles - is_polarised, num_b_alleles - is_polarised, norm, f_params); if (ret != 0) { goto out; @@ -2463,9 +2466,8 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, static int compute_general_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, - tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double *result) { int ret = 0; tsk_size_t k; @@ -2653,9 +2655,8 @@ static int tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows, - const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + void *f_params, norm_func_t *norm_f, tsk_size_t n_rows, const tsk_id_t *row_sites, + tsk_size_t n_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) { int ret = 0; tsk_bitset_t allele_samples, allele_sample_sets; @@ 
-3089,9 +3090,8 @@ advance_collect_edges(iter_state *s, tsk_id_t index) static int compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim, - tsk_size_t result_dim, int sign, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t result_dim, int sign, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double *result) { int ret = 0; double a_len, b_len; @@ -3141,8 +3141,8 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, static int compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, - iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params, - tsk_size_t result_dim, tsk_size_t state_dim, double *result) + iter_state *r_state, general_stat_func_t *f, void *f_params, tsk_size_t result_dim, + tsk_size_t state_dim, double *result) { int ret = 0; tsk_id_t e, c, ec, p, *updated_nodes = NULL; @@ -3243,9 +3243,9 @@ static int tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f), - tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols, - const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result) + void *f_params, norm_func_t *TSK_UNUSED(norm_f), tsk_size_t n_rows, + const double *row_positions, tsk_size_t n_cols, const double *col_positions, + tsk_flags_t TSK_UNUSED(options), double *result) { int ret = 0; int r, c; @@ -3385,10 +3385,10 @@ check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s } int -tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t 
*sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, - norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, +tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { @@ -3398,10 +3398,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); tsk_size_t state_dim = num_sample_sets; - sample_count_stat_params_t f_params = { .sample_sets = sample_sets, - .num_sample_sets = num_sample_sets, - .sample_set_sizes = sample_set_sizes, - .set_indexes = set_indexes }; // We do not support two-locus node stats if (!!(options & TSK_STAT_NODE)) { @@ -3441,7 +3437,7 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); } else if (stat_branch) { ret = check_positions( @@ -3455,13 +3451,30 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_positions, out_cols, col_positions, options, result); } out: return ret; 
}
 
+int
+tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
+    const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
+    tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f,
+    norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites,
+    const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites,
+    const double *col_positions, tsk_flags_t options, double *result)
+{
+    /* Forward to the general entry point with the sample-set layout as f_params. */
+    sample_count_stat_params_t count_params = { .sample_sets = sample_sets,
+        .num_sample_sets = num_sample_sets, .sample_set_sizes = sample_set_sizes,
+        .set_indexes = set_indexes };
+    return tsk_treeseq_two_locus_count_general_stat(self, num_sample_sets,
+        sample_set_sizes, sample_sets, result_dim, f, &count_params, norm_f, out_rows,
+        row_sites, row_positions, out_cols, col_sites, col_positions, options, result);
+}
+
 /***********************************
  * Allele frequency spectrum
  ***********************************/
diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 84480ed96e..2bf1a26cc9 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1036,8 +1036,8 @@ int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const doub tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a, - tsk_size_t n_b, double *result, void *params); +typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, + tsk_size_t result_dim, tsk_size_t n_a, tsk_size_t n_b, double *result, void *params); int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, @@ -1120,6 +1120,13 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes, tsk_size_t num_windows, const double *windows,
tsk_flags_t options, double *result); +int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); + typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites,
diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c
index 0e0c1c5ed5..c772270de6 100644
--- a/python/_tskitmodule.c
+++ b/python/_tskitmodule.c
@@ -7946,6 +7946,312 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim)
     return array;
 }
 
+/* Context handed to the C library as the opaque params pointer and read back
+ * by the two callbacks below. All three members are borrowed references. */
+typedef struct {
+    PyArrayObject *sample_set_sizes;
+    PyObject *summary_func;
+    PyObject *norm_func;
+} two_locus_general_stat_params;
+
+/* Wrap the K x 3 row-major buffer X in a read-only numpy array and return its
+ * (3, K) transpose, so that a callback can unpack the rows, for example
+ * pAB, pAb, paB = X / n, which also works for K > 1. Transposing creates a
+ * view, so no data is copied or reordered: the result aliases X and must not
+ * be used after X goes away. Returns a new reference, or NULL with a Python
+ * exception set. */
+static PyArrayObject *
+two_locus_state_view(tsk_size_t K, const double *X)
+{
+    PyArrayObject *ret = NULL;
+    PyArrayObject *X_array = NULL;
+    npy_intp X_dims[2] = { (npy_intp) K, 3 };
+
+    X_array = (PyArrayObject *) PyArray_SimpleNewFromData(
+        2, X_dims, NPY_FLOAT64, (void *) X);
+    if (X_array == NULL) {
+        goto out;
+    }
+    PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE);
+    ret = (PyArrayObject *) PyArray_Transpose(X_array, NULL);
+out:
+    /* The view keeps its own reference to X_array as its base, so ours is
+     * dropped here on success and on failure alike. */
+    Py_XDECREF(X_array);
+    return ret;
+}
+
+/* Convert the object returned by a Python callback to an aligned,
+ * C-contiguous float64 array, check that it is 1D with result_dim entries and
+ * copy it into Y. callback_name only appears in the error messages. Returns 0
+ * on success, or TSK_PYTHON_CALLBACK_ERROR with a Python exception set. */
+static int
+two_locus_copy_callback_result(
+    PyObject *result, const char *callback_name, tsk_size_t result_dim, double *Y)
+{
+    int ret = TSK_PYTHON_CALLBACK_ERROR;
+    PyArrayObject *Y_array = NULL;
+
+    Y_array = (PyArrayObject *) PyArray_FromAny(
+        result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL);
+    if (Y_array == NULL) {
+        goto out;
+    }
+    if (PyArray_NDIM(Y_array) != 1) {
+        PyErr_Format(PyExc_ValueError,
+            "Array returned by %s callback is %d dimensional; "
+            "must be 1D",
+            callback_name, (int) PyArray_NDIM(Y_array));
+        goto out;
+    }
+    if (PyArray_DIM(Y_array, 0) != (npy_intp) result_dim) {
+        /* The lengths may be wider than int, hence the z length modifiers. */
+        PyErr_Format(PyExc_ValueError,
+            "Array returned by %s callback is of length %zd; must be %zu",
+            callback_name, (Py_ssize_t) PyArray_DIM(Y_array, 0), (size_t) result_dim);
+        goto out;
+    }
+    memcpy(Y, PyArray_DATA(Y_array), result_dim * sizeof(*Y));
+    ret = 0;
+out:
+    Py_XDECREF(Y_array);
+    return ret;
+}
+
+/* norm_func_t adapter: calls the Python norm function as
+ * norm_func(X, sample_set_sizes, n_a, n_b), where X is the (3, K) view of the
+ * K x 3 weights and n_a, n_b are numpy int64 scalars, then copies the
+ * result_dim values it returns into Y. Returns 0 on success and
+ * TSK_PYTHON_CALLBACK_ERROR on failure. */
+static int
+general_two_locus_norm_func(tsk_size_t K, const double *X, tsk_size_t result_dim,
+    tsk_size_t n_a, tsk_size_t n_b, double *Y, void *params)
+{
+    int ret = TSK_PYTHON_CALLBACK_ERROR;
+    two_locus_general_stat_params *tl_params = params;
+    /* PyArray_Scalar copies an int64 from the address it is given, so hand it
+     * values of exactly that type. */
+    npy_int64 n_a_value = (npy_int64) n_a;
+    npy_int64 n_b_value = (npy_int64) n_b;
+    PyArray_Descr *int64_descr = NULL;
+    PyArrayObject *X_array = NULL;
+    PyObject *n_a_scalar = NULL;
+    PyObject *n_b_scalar = NULL;
+    PyObject *arglist = NULL;
+    PyObject *result = NULL;
+
+    X_array = two_locus_state_view(K, X);
+    if (X_array == NULL) {
+        goto out;
+    }
+    /* PyArray_Scalar does not take ownership of the descriptor, so one
+     * reference serves both scalars and is released at out. */
+    int64_descr = PyArray_DescrFromType(NPY_INT64);
+    if (int64_descr == NULL) {
+        goto out;
+    }
+    n_a_scalar = PyArray_Scalar(&n_a_value, int64_descr, NULL);
+    if (n_a_scalar == NULL) {
+        goto out;
+    }
+    n_b_scalar = PyArray_Scalar(&n_b_value, int64_descr, NULL);
+    if (n_b_scalar == NULL) {
+        goto out;
+    }
+    arglist = Py_BuildValue(
+        "OOOO", X_array, tl_params->sample_set_sizes, n_a_scalar, n_b_scalar);
+    if (arglist == NULL) {
+        goto out;
+    }
+    result = PyObject_CallObject(tl_params->norm_func, arglist);
+    if (result == NULL) {
+        goto out;
+    }
+    ret = two_locus_copy_callback_result(result, "norm function", result_dim, Y);
+out:
+    Py_XDECREF(int64_descr);
+    Py_XDECREF(X_array);
+    Py_XDECREF(n_a_scalar);
+    Py_XDECREF(n_b_scalar);
+    Py_XDECREF(arglist);
+    Py_XDECREF(result);
+    return ret;
+}
+
+/* general_stat_func_t adapter: calls the Python summary function as
+ * summary_func(X, sample_set_sizes), where X is the (3, K) view of the K x 3
+ * state, then copies the result_dim values it returns into Y. Returns 0 on
+ * success and TSK_PYTHON_CALLBACK_ERROR on failure. */
+static int
+general_two_locus_count_stat_func(
+    tsk_size_t K, const double *X, tsk_size_t result_dim, double *Y, void *params)
+{
+    int ret = TSK_PYTHON_CALLBACK_ERROR;
+    two_locus_general_stat_params *tl_params = params;
+    PyArrayObject *X_array = NULL;
+    PyObject *arglist = NULL;
+    PyObject *result = NULL;
+
+    X_array = two_locus_state_view(K, X);
+    if (X_array == NULL) {
+        goto out;
+    }
+    arglist = Py_BuildValue("OO", X_array, tl_params->sample_set_sizes);
+    if (arglist == NULL) {
+        goto out;
+    }
+    result = PyObject_CallObject(tl_params->summary_func, arglist);
+    if (result == NULL) {
+        goto out;
+    }
+    ret = two_locus_copy_callback_result(result, "summary function", result_dim, Y);
+out:
+    Py_XDECREF(X_array);
+    Py_XDECREF(arglist);
+    Py_XDECREF(result);
+    return ret;
+}
+
+/* Python entry point for tsk_treeseq_two_locus_count_general_stat. Arguments
+ * are sample_set_sizes, sample_sets, summary_func, norm_func, output_dim,
+ * polarised, row_sites, col_sites, row_positions, column_positions and an
+ * optional mode. Positions must be None in site mode and sites must be None in
+ * branch mode. Returns a new float64 array with shape
+ * (rows, columns, output_dim), or NULL on error. */
+static PyObject *
+TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject *kwds)
+{
+    PyObject *ret = NULL;
+    static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func",
+        "norm_func", "output_dim", "polarised", "row_sites", "col_sites",
+        "row_positions", "column_positions", "mode", NULL };
+    two_locus_general_stat_params params;
+    PyObject *summary_func = NULL;
+    PyObject *norm_func = NULL;
+    unsigned int output_dim;
+    PyObject *sample_set_sizes = NULL;
+    PyObject *sample_sets = NULL;
+    PyObject *row_sites = NULL;
+    PyObject *col_sites = NULL;
+    PyObject *row_positions = NULL;
+    PyObject *col_positions = NULL;
+    char *mode = NULL;
+    PyArrayObject *sample_set_sizes_array = NULL;
+    PyArrayObject *sample_sets_array = NULL;
+    PyArrayObject *row_sites_array = NULL;
+    PyArrayObject *col_sites_array = NULL;
+    PyArrayObject *row_positions_array = NULL;
+    PyArrayObject *col_positions_array = NULL;
+    PyArrayObject *result_matrix = NULL;
+    tsk_id_t *row_sites_parsed = NULL;
+    tsk_id_t *col_sites_parsed = NULL;
+    double *row_positions_parsed = NULL;
+    double *col_positions_parsed = NULL;
+    npy_intp result_dim[3] = { 0, 0, 0 };
+    tsk_size_t num_sample_sets;
+    tsk_flags_t options = 0;
+    int polarised = 0;
+    int err;
+
+    if (TreeSequence_check_state(self) != 0) {
+        goto out;
+    }
+    if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOIiOOOO|s", kwlist,
+            &sample_set_sizes, &sample_sets, &summary_func, &norm_func, &output_dim,
+            &polarised, &row_sites, &col_sites, &row_positions, &col_positions, &mode)) {
+        /* Balance the decrefs at out for whichever callables were parsed. */
+        Py_XINCREF(summary_func);
+        Py_XINCREF(norm_func);
+        goto out;
+    }
+    Py_INCREF(summary_func);
+    Py_INCREF(norm_func);
+    if (!PyCallable_Check(summary_func)) {
+        PyErr_SetString(PyExc_TypeError, "summary_func must be callable");
+        goto out;
+    }
+    if (!PyCallable_Check(norm_func)) {
+        PyErr_SetString(PyExc_TypeError, "norm_func must be callable");
+        goto out;
+    }
+    if (parse_stats_mode(mode, &options) != 0) {
+        goto out;
+    }
+    if (polarised) {
+        options |= TSK_STAT_POLARISED;
+    }
+
+    if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets,
+            &sample_sets_array, &num_sample_sets)
+        != 0) {
+        goto out;
+    }
+    PyArray_CLEARFLAGS(sample_set_sizes_array, NPY_ARRAY_WRITEABLE);
+
+    if (options & TSK_STAT_SITE) {
+        if (row_positions != Py_None || col_positions != Py_None) {
+            PyErr_SetString(PyExc_ValueError, "Cannot specify positions in site mode");
+            goto out;
+        }
+        row_sites_array = parse_sites(self, row_sites, &(result_dim[0]));
+        col_sites_array = parse_sites(self, col_sites, &(result_dim[1]));
+        if (row_sites_array == NULL || col_sites_array == NULL) {
+            goto out;
+        }
+        row_sites_parsed = PyArray_DATA(row_sites_array);
+        col_sites_parsed = PyArray_DATA(col_sites_array);
+    } else if (options & TSK_STAT_BRANCH) {
+        if (row_sites != Py_None || col_sites != Py_None) {
+            PyErr_SetString(PyExc_ValueError, "Cannot specify sites in branch mode");
+            goto out;
+        }
+        row_positions_array = parse_positions(self, row_positions, &(result_dim[0]));
+        col_positions_array = parse_positions(self, col_positions, &(result_dim[1]));
+        if (col_positions_array == NULL || row_positions_array == NULL) {
+            goto out;
+        }
+        row_positions_parsed = PyArray_DATA(row_positions_array);
+        col_positions_parsed = PyArray_DATA(col_positions_array);
+    }
+
+    result_dim[2] = output_dim;
+    result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0);
+    if (result_matrix == NULL) {
+        PyErr_NoMemory();
+        goto out;
+    }
+
+    params.sample_set_sizes = sample_set_sizes_array;
+    params.summary_func = summary_func;
+    params.norm_func = norm_func;
+    err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets,
+        PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array),
+        output_dim, general_two_locus_count_stat_func, &params,
+        general_two_locus_norm_func, result_dim[0], row_sites_parsed,
+        row_positions_parsed, result_dim[1], col_sites_parsed, col_positions_parsed,
+        options, PyArray_DATA(result_matrix));
+
+    if (err == TSK_PYTHON_CALLBACK_ERROR) {
+        goto out;
+    } else if (err != 0) {
+        handle_library_error(err);
+        goto out;
+    }
+    ret = (PyObject *) result_matrix;
+    result_matrix = NULL;
+out:
+    Py_XDECREF(summary_func);
+    Py_XDECREF(norm_func);
+    Py_XDECREF(row_sites_array);
+    Py_XDECREF(col_sites_array);
+    Py_XDECREF(row_positions_array);
+    Py_XDECREF(col_positions_array);
+    Py_XDECREF(sample_sets_array);
+    Py_XDECREF(sample_set_sizes_array);
+    Py_XDECREF(result_matrix);
+    return ret;
+}
+
 static PyObject *
 TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds,
     two_locus_count_stat_method *method)
@@ -8831,6 +9137,11 @@ static PyMethodDef TreeSequence_methods[] = {
         .ml_meth = (PyCFunction) TreeSequence_general_stat,
         .ml_flags = METH_VARARGS | METH_KEYWORDS,
         .ml_doc = "Runs the general stats algorithm for a given summary function." },
+    { .ml_name = "two_locus_count_stat",
+        .ml_meth = (PyCFunction) TreeSequence_two_locus_count_stat,
+        .ml_flags = METH_VARARGS | METH_KEYWORDS,
+        .ml_doc
+        = "Runs the general two locus stats algorithm for a given summary function." },
     { .ml_name = "diversity",
         .ml_meth = (PyCFunction) TreeSequence_diversity,
         .ml_flags = METH_VARARGS | METH_KEYWORDS,
diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 4d6e47ddcc..953f542f30 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2398,3 +2398,256 @@ def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, ex norm_hap_weighted_ij(1, state, max(a) + 1, max(b) + 1, norm[i, j], params) np.testing.assert_allclose((result * norm).sum(), expected) + + +class GeneralStatFuncs: + """ + functions take X, n as parameters where + + X: shape=(3, #ss) + sample sets + count AB [[ ] + count Ab [ ] + count aB [ ]] + + n: shape=(#ss, ) + [ ] + """ + + @staticmethod + def D(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + @staticmethod + def D2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return (pAB - (pA * pB)) ** 2 + + @staticmethod + def r2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): +
return D**2 / denom + + @staticmethod + def r(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): + return D / np.sqrt(denom) + + @staticmethod + def D_prime(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = np.vstack( + [ + np.min([pA * (1 - pB), (1 - pA) * pB], axis=0), + np.min([pA * pB, (1 - pA) * (1 - pB)], axis=0), + ] + ) + with suppress_overflow_div0_warning(): + return D / denom[(D < 0).astype(int), range(len(D))] + + @staticmethod + def Dz(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + return D * (1 - 2 * pA) * (1 - 2 * pB) + + @staticmethod + def pi2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pA * (1 - pA) * pB * (1 - pB) + + @staticmethod + def D2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((aB**2) * (Ab - 1) * Ab) + + ((ab - 1) * ab * (AB - 1) * AB) + - (aB * Ab * (Ab + (2 * ab * AB) - 1)) + ) + + @staticmethod + def Dz_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + (((AB * ab) - (Ab * aB)) * (aB + ab - AB - Ab) * (Ab + ab - AB - aB)) + - ((AB * ab) * (AB + ab - Ab - aB - 2)) + - ((Ab * aB) * (Ab + aB - AB - ab - 2)) + ) + + @staticmethod + def pi2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((AB + Ab) * (aB + ab) * (AB + aB) * (Ab + ab)) + - ((AB * ab) * (AB + ab + (3 * Ab) + (3 * aB) - 1)) + - ((Ab * aB) * (Ab + aB + (3 * AB) + (3 * ab) - 1)) + ) + + @staticmethod + def r2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D2_ij = np.prod(pAB - (pA * pB)) + denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB))) + with suppress_overflow_div0_warning(): + return np.expand_dims(D2_ij / denom, axis=0) + + 
@staticmethod + def D2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return np.expand_dims(np.prod(pAB - (pA * pB)), axis=0) + + @staticmethod + def D2_ij_unbiased(X, n): + """ + NB: the two sample sets must be disjoint + we have no way for testing equality + """ + AB, Ab, aB = X + ab = n - X.sum(0) + return np.expand_dims( + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / n[0] + / (n[0] - 1) + / n[1] + / (n[1] - 1), + axis=0, + ) + + +@pytest.mark.parametrize( + "ts,stat", + [ + ( + ts := [ + p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" + ][0].values[0], + "D", + ), + (ts, "D2"), + (ts, "r2"), + (ts, "r"), + (ts, "D_prime"), + (ts, "Dz"), + (ts, "pi2"), + (ts, "D2_unbiased"), + (ts, "Dz_unbiased"), + (ts, "pi2_unbiased"), + ], +) +def test_general_two_locus_site_stat(ts, stat): + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) + ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) + np.testing.assert_array_almost_equal(ldg, ld) + + +@pytest.mark.parametrize( + "ts,stat", + [ + ( + ts := [ + p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" + ][0].values[0], + "r2_ij", + ), + (ts, "D2_ij"), + (ts, "D2_ij_unbiased"), + ], +) +def test_general_two_locus_two_way_site_stat(ts, stat): + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) + ld = ts.ld_matrix( + sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=(0, 1) + ) + np.testing.assert_array_almost_equal(ldg, ld) + + +@pytest.mark.parametrize( + "stat", + [ + "D", + "D2", + "r2", + "r", + "D_prime", + "Dz", + "pi2", + "D2_unbiased", + "Dz_unbiased", + "pi2_unbiased", + ], +) +def test_general_one_way_two_locus_stat_multiallelic(stat): + (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + func = 
getattr(GeneralStatFuncs, stat) + if stat == "r2": + result = ts.two_locus_count_stat( + [ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n + ) + elif stat in {"D", "r", "D_prime"}: + result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + else: + # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + result = ts.two_locus_count_stat([ts.samples()], func, 1) + np.testing.assert_array_almost_equal(ts.ld_matrix(stat=stat), result) + + +@pytest.mark.parametrize( + "stat", + [ + "r2_ij", + "D2_ij", + "D2_ij_unbiased", + ], +) +def test_general_two_way_two_locus_stat_multiallelic(stat): + (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + func = getattr(GeneralStatFuncs, stat) + if stat == "r2_ij": + + def norm_f(X, n, nA, nB): + return np.expand_dims(X[0].sum() / n.sum(), axis=0) + + result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) + else: + # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) + np.testing.assert_array_almost_equal( + ts.ld_matrix( + stat=stat.replace("_ij", ""), + indexes=(0, 1), + sample_sets=[ts.samples(), ts.samples()], + ), + result, + ) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 15f9967f3f..625d8f9bcb 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -1987,6 +1987,166 @@ def test_ld_matrix_multipop(self, stat_method_name): with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): stat_method(ss_sizes, ss, indexes, col_sites, row_sites, None, None, "node") + def test_two_locus_count_stat(self): + ts = self.get_example_tree_sequence(10) + ss = ts.get_samples() # sample sets + ss_sizes = np.array([len(ss)], dtype=np.uint32) + row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) + col_sites = row_sites + row_pos = ts.get_breakpoints()[:-1] + 
col_pos = row_pos + row_sites_list = list(range(ts.get_num_sites())) + col_sites_list = row_sites_list + row_pos_list = list(map(float, ts.get_breakpoints()[:-1])) + col_pos_list = row_pos_list + + def stat_func(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + def norm_func(X, n, nA, nB): + return np.expand_dims(X[0].sum() / n.sum(), axis=0) + + method = ts.two_locus_count_stat + + site_args = row_sites, col_sites, None, None, "site" + branch_args = None, None, row_pos, col_pos, "branch" + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) + assert a.shape == (10, 10, 1) + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_args) + assert a.shape == (2, 2, 1) + site_list_args = row_sites_list, col_sites_list, None, None, "site" + branch_list_args = None, None, row_pos_list, col_pos_list, "branch" + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args) + assert a.shape == (10, 10, 1) + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) + assert a.shape == (2, 2, 1) + # CPython API errors + with pytest.raises(ValueError, match="Sum of sample_set_sizes"): + bad_ss = np.array([], dtype=np.int32) + method(ss_sizes, bad_ss, stat_func, norm_func, 1, True, *site_args) + with pytest.raises(TypeError, match="cast array data"): + bad_ss = np.array(ts.get_samples(), dtype=np.uint32) + method(ss_sizes, bad_ss, stat_func, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="Unrecognised stats mode"): + bad_args = row_sites, col_sites, None, None, "bla" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_args) + with pytest.raises(TypeError, match="at most"): + method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args, "extraarg") + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, 
*bad_site_args) + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0.1, 0.2, 2.0] + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError): + bad_pos = [{}, 0.1, 0.2, 2.0] + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0, 3, 2] + bad_branch_args = None, None, 
row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError): + bad_pos = [{}, 0, 3, 2] + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(ValueError, match="Cannot specify positions in site mode"): + bad_site_args = None, None, row_pos, col_pos, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError, match="Cannot specify sites in branch mode"): + bad_branch_args = row_sites, col_sites, None, None, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError, match="summary_func must be callable"): + method(ss_sizes, ss, "uncallable", norm_func, 1, True, *site_args) + with pytest.raises(TypeError, match="norm_func must be callable"): + method(ss_sizes, ss, stat_func, "uncallable", 1, True, *site_args) + with pytest.raises(ValueError, match="summary function.*must be 1D"): + method(ss_sizes, ss, lambda a, b: 1, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="length 2; must be 1"): + method(ss_sizes, ss, lambda a, b: [1, 2], norm_func, 1, True, *site_args) + # C API errors + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + 
method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS"): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS"): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, 
match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 45d2da59e0..6fc2bc0b4d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8249,19 +8249,7 @@ def parse_positions(self, positions): ) return row_positions, col_positions - def __two_locus_sample_set_stat( - self, - ll_method, - sample_sets, - sites=None, - positions=None, - mode=None, - ): - if sample_sets is None: - sample_sets = self.samples() - row_sites, col_sites = self.parse_sites(sites) - row_positions, col_positions = self.parse_positions(positions) - + def __convert_sample_sets(self, sample_sets): # First try to convert to a 1D numpy array. If we succeed, then we strip off # the corresponding dimension from the output. 
drop_dimension = False @@ -8283,7 +8271,23 @@ def __two_locus_sample_set_stat( raise ValueError("Sample sets must contain at least one element") flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + return drop_dimension, flattened, sample_set_sizes + def __two_locus_sample_set_stat( + self, + ll_method, + sample_sets, + sites=None, + positions=None, + mode=None, + ): + if sample_sets is None: + sample_sets = self.samples() + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( + sample_sets + ) result = ll_method( sample_set_sizes, flattened, @@ -10927,6 +10931,39 @@ def impute_unknown_mutations_time( mutations_time[unknown] = self.nodes_time[self.mutations_node[unknown]] return mutations_time + def two_locus_count_stat( + self, + sample_sets, + f, + result_dim, + norm_f=lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0), + polarised=False, + sites=None, + positions=None, + mode="site", + ): + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) + result = self._ll_tree_sequence.two_locus_count_stat( + sample_set_sizes, + sample_sets, + f, + norm_f, + result_dim, + polarised, + row_sites, + col_sites, + row_positions, + col_positions, + mode, + ) + if result_dim == 1: # drop dimension + return result.reshape(result.shape[:2]) + # Orient the data so that the first dimension is the sample set so that + # we get one LD matrix per sample set. + return result.swapaxes(0, 2).swapaxes(1, 2) + def ld_matrix( self, sample_sets=None,