diff --git a/roofit/hs3/src/JSONFactories_HistFactory.cxx b/roofit/hs3/src/JSONFactories_HistFactory.cxx index df59000016fc0..aac9fb65522fa 100644 --- a/roofit/hs3/src/JSONFactories_HistFactory.cxx +++ b/roofit/hs3/src/JSONFactories_HistFactory.cxx @@ -283,6 +283,36 @@ getOrCreateConstraint(RooJSONFactoryWSTool &tool, const JSONNode &mod, RooRealVa "'"); } } +double poissonTau(RooPoisson const &constraint, RooAbsArg const &gamma) +{ + auto const *mean = dynamic_cast(&constraint.getMean()); + if (!mean) { + RooJSONFactoryWSTool::error( + "Poisson gamma constraint mean is not a RooProduct: " + std::string(constraint.GetName())); + } + + for (RooAbsArg *arg : mean->servers()) { + if (arg == &gamma) { + continue; + } + + if (auto const *tau = dynamic_cast(arg)) { + return tau->getVal(); + } + + // Imported workspaces can sometimes represent + // constants as constant RooRealVars. + if (auto const *real = dynamic_cast(arg)) { + if (real->isConstant() || endsWith(std::string(real->GetName()), "_tau")) { + return real->getVal(); + } + } + } + + RooJSONFactoryWSTool::error( + "Could not find tau component in Poisson gamma constraint mean: " + std::string(constraint.GetName())); + return std::numeric_limits::quiet_NaN(); +} bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet const &varlist, RooAbsArg const *mcStatObject, const std::string &fprefix, const JSONNode &p, @@ -334,6 +364,7 @@ bool importHistSample(RooJSONFactoryWSTool &tool, RooDataHist &dh, RooArgSet con // this is dealt with at a different place, ignore it for now } else if (modtype == "normfactor") { RooRealVar &constrParam = getOrCreate(ws, sysname, 1., -3, 5); + constrParam.setError(0.0); normElems.add(constrParam); if (mod.has_child("constraint_name") || mod.has_child("constraint_type")) { // for norm factors, constraints are optional @@ -1060,7 +1091,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons if (constraint) { sample.barlowBeestonLightConstraintType = constraint->IsA(); if (RooPoisson *constraint_p = dynamic_cast(constraint)) { - double erel = 1. / std::sqrt(constraint_p->getX().getVal()); + double erel = 1. / std::sqrt(poissonTau(*constraint_p, *g)); channel.rel_errors[idx] = erel; } else if (RooGaussian *constraint_g = dynamic_cast(constraint)) { double erel = constraint_g->getSigma().getVal() / constraint_g->getMean().getVal(); @@ -1094,7 +1125,7 @@ Channel readChannel(RooJSONFactoryWSTool *tool, const std::string &pdfname, cons if (!constraint) { sys.constraints.push_back(0.0); } else if (auto constraint_p = dynamic_cast(constraint)) { - sys.constraints.push_back(1. / std::sqrt(constraint_p->getX().getVal())); + sys.constraints.push_back(1. / std::sqrt(poissonTau(*constraint_p, *g))); if (!sys.constraint) { sys.constraintType = RooPoisson::Class(); } diff --git a/roofit/hs3/src/JSONFactories_RooFitCore.cxx b/roofit/hs3/src/JSONFactories_RooFitCore.cxx index c3b771a9ee033..2e8e825602a96 100644 --- a/roofit/hs3/src/JSONFactories_RooFitCore.cxx +++ b/roofit/hs3/src/JSONFactories_RooFitCore.cxx @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -878,6 +880,32 @@ class RooPoissonStreamer : public RooFit::JSONIO::Exporter { } }; +class RooGaussianStreamer : public RooFit::JSONIO::Exporter { +public: + std::string const &key() const override; + bool autoExportDependants() const override { return false; } + bool exportObject(RooJSONFactoryWSTool *tool, const RooAbsArg *func, JSONNode &elem) const override + { + auto *pdf = static_cast(func); + elem["type"] << key(); + writeArg(tool, elem["x"], pdf->getX()); + writeArg(tool, elem["mean"], pdf->getMean()); + writeArg(tool, elem["sigma"], pdf->getSigma()); + return true; + } + +private: + static void writeArg(RooJSONFactoryWSTool *tool, JSONNode &node, RooAbsReal const &arg) + { + if (auto const *constant = dynamic_cast(&arg)) { + node << constant->getVal(); + } else { + node << arg.GetName(); + tool->queueExport(arg); + } + } +}; + class RooDecayStreamer : public RooFit::JSONIO::Exporter { public: std::string const &key() const override; @@ -1171,6 +1199,7 @@ DEFINE_EXPORTER_KEY(RooHistPdfStreamer, "histogram_dist"); DEFINE_EXPORTER_KEY(RooLogNormalStreamer, "lognormal_dist"); DEFINE_EXPORTER_KEY(RooMultiVarGaussianStreamer, "multivariate_normal_dist"); DEFINE_EXPORTER_KEY(RooPoissonStreamer, "poisson_dist"); +DEFINE_EXPORTER_KEY(RooGaussianStreamer, "gaussian_dist"); DEFINE_EXPORTER_KEY(RooDecayStreamer, "decay_dist"); DEFINE_EXPORTER_KEY(RooTruthModelStreamer, "truth_model_function"); DEFINE_EXPORTER_KEY(RooGaussModelStreamer, "gauss_model_function"); @@ -1235,6 +1264,7 @@ STATIC_EXECUTE([]() { registerExporter(RooLognormal::Class(), false); registerExporter(RooMultiVarGaussian::Class(), false); registerExporter(RooPoisson::Class(), false); + registerExporter(RooGaussian::Class(), false); registerExporter(RooDecay::Class(), false); registerExporter(RooTruthModel::Class(), false); registerExporter(RooGaussModel::Class(), false); diff --git a/roofit/hs3/test/testRooFitHS3.cxx b/roofit/hs3/test/testRooFitHS3.cxx index 86481361a5812..6b788d3157d02 100644 --- a/roofit/hs3/test/testRooFitHS3.cxx +++ b/roofit/hs3/test/testRooFitHS3.cxx @@ -259,6 +259,38 @@ TEST(RooFitHS3, RooGaussian) EXPECT_EQ(status, 0); } +TEST(RooFitHS3, RooGaussianConstVarSigmaExport) +{ + RooRealVar x{"x", "x", 0.0, -10.0, 10.0}; + RooRealVar mean{"mean", "mean", 0.0}; + mean.setConstant(true); + + RooConstVar sigmaConst{"sigma_const", "sigma_const", 1.0}; + RooGaussian gaussConst{"gauss_const", "gauss_const", x, mean, sigmaConst}; + + RooRealVar sigmaReal{"sigma_real", "sigma_real", 1.0, 0.1, 10.0}; + sigmaReal.setConstant(true); + RooGaussian gaussReal{"gauss_real", "gauss_real", x, mean, sigmaReal}; + + RooWorkspace ws; + ws.import(gaussConst, RooFit::Silence()); + ws.import(gaussReal, RooFit::RecycleConflictNodes(), RooFit::Silence()); + + const std::string json = RooJSONFactoryWSTool{ws}.exportJSONtoString(); + + EXPECT_EQ(json.find("\"sigma\":\"sigma_const\""), std::string::npos); + EXPECT_EQ(json.find("\"name\":\"sigma_const\""), std::string::npos); + EXPECT_NE(json.find("\"sigma\":1.0"), std::string::npos); + + EXPECT_NE(json.find("\"sigma\":\"sigma_real\""), std::string::npos); + EXPECT_NE(json.find("\"name\":\"sigma_real\""), std::string::npos); + + RooWorkspace imported; + RooJSONFactoryWSTool{imported}.importJSONfromString(json); + EXPECT_EQ(imported.obj("sigma_const"), nullptr); + EXPECT_NE(dynamic_cast(imported.obj("sigma_real")), nullptr); +} + TEST(RooFitHS3, RooBernstein) { int status = validate({"RooBernstein::bernstein(x[0, 10], { a[1], 3, b[5, 0, 20] })"});