Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions roofit/hs3/src/JSONFactories_HistFactory.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,36 @@ getOrCreateConstraint(RooJSONFactoryWSTool &tool, const JSONNode &mod, RooRealVa
"'");
}
}
double poissonTau(RooPoisson const &constraint, RooAbsArg const &gamma)
{
auto const *mean = dynamic_cast<RooProduct const *>(&constraint.getMean());
if (!mean) {
RooJSONFactoryWSTool::error(
"Poisson gamma constraint mean is not a RooProduct: " + std::string(constraint.GetName()));
}

for (RooAbsArg *arg : mean->servers()) {
if (arg == &gamma) {
continue;
}

if (auto const *tau = dynamic_cast<RooConstVar const *>(arg)) {
return tau->getVal();
}

// Imported workspaces can sometimes represent
// constants as constant RooRealVars.
if (auto const *real = dynamic_cast<RooAbsReal const *>(arg)) {
if (real->isConstant() || endsWith(std::string(real->GetName()), "_tau")) {
return real->getVal();
}
}
}

RooJSONFactoryWSTool::error(
"Could not find tau component in Poisson gamma constraint mean: " + std::string(constraint.GetName()));
return std::numeric_limits<double>::quiet_NaN();
}

bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet const &varlist,
RooAbsArg const *mcStatObject, const std::string &fprefix, const JSONNode &p,
Expand Down Expand Up @@ -334,6 +364,7 @@ bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet con
// this is dealt with at a different place, ignore it for now
} else if (modtype == "normfactor") {
RooRealVar &constrParam = getOrCreate<RooRealVar>(ws, sysname, 1., -3, 5);
constrParam.setError(0.0);
normElems.add(constrParam);
if (mod.has_child("constraint_name") || mod.has_child("constraint_type")) {
// for norm factors, constraints are optional
Expand Down Expand Up @@ -1060,7 +1091,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
if (constraint) {
sample.barlowBeestonLightConstraintType = constraint->IsA();
if (RooPoisson *constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
double erel = 1. / std::sqrt(constraint_p->getX().getVal());
double erel = 1. / std::sqrt(poissonTau(*constraint_p, *g));
channel.rel_errors[idx] = erel;
} else if (RooGaussian *constraint_g = dynamic_cast<RooGaussian *>(constraint)) {
double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal();
Expand Down Expand Up @@ -1094,7 +1125,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons
if (!constraint) {
sys.constraints.push_back(0.0);
} else if (auto constraint_p = dynamic_cast<RooPoisson *>(constraint)) {
sys.constraints.push_back(1. / std::sqrt(constraint_p->getX().getVal()));
sys.constraints.push_back(1. / std::sqrt(poissonTau(*constraint_p, *g)));
if (!sys.constraint) {
sys.constraintType = RooPoisson::Class();
}
Expand Down
30 changes: 30 additions & 0 deletions roofit/hs3/src/JSONFactories_RooFitCore.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <RooBinSamplingPdf.h>
#include <RooBinWidthFunction.h>
#include <RooCategory.h>
#include <RooConstVar.h>
#include <RooDataHist.h>
#include <RooDecay.h>
#include <RooDerivative.h>
Expand All @@ -29,6 +30,7 @@
#include <RooFitHS3/JSONIO.h>
#include <RooFormulaVar.h>
#include <RooGenericPdf.h>
#include <RooGaussian.h>
#include <RooHistFunc.h>
#include <RooHistPdf.h>
#include <RooLegacyExpPoly.h>
Expand Down Expand Up @@ -878,6 +880,32 @@ class RooPoissonStreamer : public RooFit::JSONIO::Exporter {
}
};

class RooGaussianStreamer : public RooFit::JSONIO::Exporter {
public:
std::string const &key() const override;
bool autoExportDependants() const override { return false; }
bool exportObject(RooJSONFactoryWSTool *tool, const RooAbsArg *func, JSONNode &elem) const override
{
auto *pdf = static_cast<const RooGaussian *>(func);
elem["type"] << key();
writeArg(tool, elem["x"], pdf->getX());
writeArg(tool, elem["mean"], pdf->getMean());
writeArg(tool, elem["sigma"], pdf->getSigma());
return true;
}

private:
static void writeArg(RooJSONFactoryWSTool *tool, JSONNode &node, RooAbsReal const &arg)
{
if (auto const *constant = dynamic_cast<RooConstVar const *>(&arg)) {
node << constant->getVal();
} else {
node << arg.GetName();
tool->queueExport(arg);
}
}
};

class RooDecayStreamer : public RooFit::JSONIO::Exporter {
public:
std::string const &key() const override;
Expand Down Expand Up @@ -1171,6 +1199,7 @@ DEFINE_EXPORTER_KEY(RooHistPdfStreamer, "histogram_dist");
DEFINE_EXPORTER_KEY(RooLogNormalStreamer, "lognormal_dist");
DEFINE_EXPORTER_KEY(RooMultiVarGaussianStreamer, "multivariate_normal_dist");
DEFINE_EXPORTER_KEY(RooPoissonStreamer, "poisson_dist");
DEFINE_EXPORTER_KEY(RooGaussianStreamer, "gaussian_dist");
DEFINE_EXPORTER_KEY(RooDecayStreamer, "decay_dist");
DEFINE_EXPORTER_KEY(RooTruthModelStreamer, "truth_model_function");
DEFINE_EXPORTER_KEY(RooGaussModelStreamer, "gauss_model_function");
Expand Down Expand Up @@ -1235,6 +1264,7 @@ STATIC_EXECUTE([]() {
registerExporter<RooLogNormalStreamer>(RooLognormal::Class(), false);
registerExporter<RooMultiVarGaussianStreamer>(RooMultiVarGaussian::Class(), false);
registerExporter<RooPoissonStreamer>(RooPoisson::Class(), false);
registerExporter<RooGaussianStreamer>(RooGaussian::Class(), false);
registerExporter<RooDecayStreamer>(RooDecay::Class(), false);
registerExporter<RooTruthModelStreamer>(RooTruthModel::Class(), false);
registerExporter<RooGaussModelStreamer>(RooGaussModel::Class(), false);
Expand Down
32 changes: 32 additions & 0 deletions roofit/hs3/test/testRooFitHS3.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,38 @@ TEST(RooFitHS3, RooGaussian)
EXPECT_EQ(status, 0);
}

TEST(RooFitHS3, RooGaussianConstVarSigmaExport)
{
RooRealVar x{"x", "x", 0.0, -10.0, 10.0};
RooRealVar mean{"mean", "mean", 0.0};
mean.setConstant(true);

RooConstVar sigmaConst{"sigma_const", "sigma_const", 1.0};
RooGaussian gaussConst{"gauss_const", "gauss_const", x, mean, sigmaConst};

RooRealVar sigmaReal{"sigma_real", "sigma_real", 1.0, 0.1, 10.0};
sigmaReal.setConstant(true);
RooGaussian gaussReal{"gauss_real", "gauss_real", x, mean, sigmaReal};

RooWorkspace ws;
ws.import(gaussConst, RooFit::Silence());
ws.import(gaussReal, RooFit::RecycleConflictNodes(), RooFit::Silence());

const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString();

EXPECT_EQ(json.find("\"sigma\":\"sigma_const\""), std::string::npos);
EXPECT_EQ(json.find("\"name\":\"sigma_const\""), std::string::npos);
EXPECT_NE(json.find("\"sigma\":1.0"), std::string::npos);

EXPECT_NE(json.find("\"sigma\":\"sigma_real\""), std::string::npos);
EXPECT_NE(json.find("\"name\":\"sigma_real\""), std::string::npos);

RooWorkspace imported;
RooJSONFactoryWSTool{imported}.importJSONfromString(json);
EXPECT_EQ(imported.obj("sigma_const"), nullptr);
EXPECT_NE(dynamic_cast<RooRealVar *>(imported.obj("sigma_real")), nullptr);
}

TEST(RooFitHS3, RooBernstein)
{
int status = validate({"RooBernstein::bernstein(x[0, 10], { a[1], 3, b[5, 0, 20] })"});
Expand Down