Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -556,12 +556,22 @@ cc_library(
":basics",
":configs",
":mat",
":threading_context",
"//io",
"@highway//:hwy",
"@highway//:profiler",
],
)

cc_test(
name = "gemma_args_test",
srcs = ["gemma/gemma_args_test.cc"],
deps = [
":gemma_args",
"@googletest//:gtest_main", # buildcleaner: keep
],
)

cc_library(
name = "gemma_lib",
srcs = [
Expand Down Expand Up @@ -666,7 +676,6 @@ cc_library(
":gemma_args",
":gemma_lib",
":matmul_env",
":ops",
":threading_context",
":tokenizer",
"@google_benchmark//:benchmark",
Expand Down
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ set(GEMMA_TEST_FILES
compression/nuq_test.cc
compression/sfp_test.cc
evals/gemma_test.cc
gemma/gemma_args_test.cc
gemma/flash_attention_test.cc
gemma/tensor_info_test.cc
io/blob_store_test.cc
Expand Down
2 changes: 0 additions & 2 deletions compression/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,11 @@ cc_test(
# for test_suite.
tags = ["hwy_ops_test"],
deps = [
":distortion",
":int",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
],
)

Expand Down
15 changes: 12 additions & 3 deletions evals/benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ using json = nlohmann::json;

class BenchmarkArgs : public ArgsBase<BenchmarkArgs> {
public:
BenchmarkArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
BenchmarkArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}

Path summarize_text;
Path cross_entropy;
Expand Down Expand Up @@ -127,9 +129,16 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
} // namespace gcpp

int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::BenchmarkArgs benchmark_args(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::BenchmarkArgs benchmark_args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
args.Help();
return 0;
}
consumed.AbortIfUnconsumed();

gcpp::GemmaEnv env(args);
if (!benchmark_args.summarize_text.Empty()) {
return BenchmarkSummary(env, benchmark_args.summarize_text);
} else if (!benchmark_args.cross_entropy.Empty()) {
Expand Down
70 changes: 23 additions & 47 deletions evals/benchmark_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,29 @@

namespace gcpp {

GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference)
GemmaEnv::GemmaEnv(const GemmaArgs& args)
: initializer_value_(gcpp::InternalInit()),
ctx_(threading),
ctx_(args.threading),
env_(ctx_),
gemma_(loader, inference, ctx_) {
gemma_(args, ctx_) {
const ModelConfig& config = gemma_.Config();
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
kv_caches_.push_back(KVCache(config, args.inference, ctx_.allocator));

if (inference.verbosity >= 2) {
ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
ctx_);
if (args.inference.verbosity >= 2) {
ShowConfig(args, config, gemma_.WeightReadMode(), ctx_);
}
if (inference.verbosity >= 3) env_.print_best = true;
if (inference.verbosity >= 4) env_.print_config = true;
if (args.inference.verbosity >= 3) env_.print_best = true;
if (args.inference.verbosity >= 4) env_.print_config = true;

runtime_config_ = {
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.verbosity = inference.verbosity,
.max_generated_tokens = args.inference.max_generated_tokens,
.temperature = args.inference.temperature,
.verbosity = args.inference.verbosity,
};
inference.CopyTo(runtime_config_);
args.inference.CopyTo(runtime_config_);
}

GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), ThreadingArgs(argc, argv),
InferenceArgs(argc, argv)) {}

QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;

Expand Down Expand Up @@ -234,19 +228,19 @@ static constexpr const char* CompiledConfig() {
}
}

void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference, const ModelConfig& config,
void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
const WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx) {
threading.Print(inference.verbosity);
loader.Print(inference.verbosity);
inference.Print(inference.verbosity);
fprintf(
stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n",
config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode));
args.threading.Print(args.inference.verbosity);
args.loader.Print(args.inference.verbosity);
args.inference.Print(args.inference.verbosity);
fprintf(stderr,
"Model : %s, to_bf16 %d, mmap %d => %s\n",
config.Specifier().c_str(), static_cast<int>(args.loader.to_bf16),
static_cast<int>(args.loader.map),
WeightsPtrs::ToString(weight_read_mode));

if (inference.verbosity >= 2) {
if (args.inference.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
char cpu100[100] = "unknown";
Expand All @@ -259,30 +253,12 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
"Instruction set : %s (%zu bits)\n"
"Compiled config : %s, profiler %d\n"
"Memory MiB : %4zu\n",
dt, cpu100, static_cast<int>(threading.bind),
dt, cpu100, static_cast<int>(args.threading.bind),
ctx.topology.TopologyString(), ctx.pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
ctx.cache_info.VectorBytes() * 8, CompiledConfig(),
PROFILER_ENABLED, ctx.allocator.TotalMiB());
}
}

void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
"To run with pre-2025 weights, specify --tokenizer and --weights.\n"
"With the single-file weights format, specify just --weights.\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights gemma2-2b-it-sfp.sbs\n";
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n";
}

} // namespace gcpp
13 changes: 4 additions & 9 deletions evals/benchmark_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

#include "gemma/configs.h"
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/gemma_args.h" // IWYU pragma: export
#include "gemma/tokenizer.h" // WrapAndTokenize
#include "ops/matmul.h"
#include "util/threading_context.h"
Expand All @@ -50,10 +50,8 @@ struct QueryResultAndMetrics {
// Convenience class to load a model and run inference.
class GemmaEnv {
public:
// Calls the other constructor with *Args arguments initialized from argv.
GemmaEnv(int argc, char** argv);
GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
explicit GemmaEnv(const GemmaArgs& args);

MatMulEnv& Env() { return env_; }

size_t MaxGeneratedTokens() const {
Expand Down Expand Up @@ -137,12 +135,9 @@ class GemmaEnv {
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);

void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference, const ModelConfig& config,
void ShowConfig(const GemmaArgs& args, const ModelConfig& config,
WeightsPtrs::Mode weight_read_mode,
const ThreadingContext& ctx);
void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);

} // namespace gcpp

Expand Down
6 changes: 5 additions & 1 deletion evals/benchmarks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ BENCHMARK(BM_coding_prompt)
->UseRealTime();

int main(int argc, char** argv) {
gcpp::GemmaEnv env(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();

gcpp::GemmaEnv env(args);
env.SetMaxGeneratedTokens(256);
gcpp::s_env = &env;

Expand Down
12 changes: 9 additions & 3 deletions evals/debug_prompt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ namespace gcpp {

class PromptArgs : public ArgsBase<PromptArgs> {
public:
PromptArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
PromptArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}

Path layers_output; // optional
std::string prompt;
Expand All @@ -51,11 +53,15 @@ class PromptArgs : public ArgsBase<PromptArgs> {
};

int Run(int argc, char** argv) {
PromptArgs prompt_args(argc, argv);
ConsumedArgs consumed(argc, argv);
const GemmaArgs args(argc, argv, consumed);
const PromptArgs prompt_args(argc, argv, consumed);
AbortIfInvalidArgs(prompt_args);
consumed.AbortIfUnconsumed();

json json_output;
GemmaEnv env(argc, argv);
GemmaEnv env(args);

env.MutableConfig().layers_output =
prompt_args.layers_output.Empty()
? LayersOutputFunc()
Expand Down
6 changes: 5 additions & 1 deletion evals/gemma_batch_bench.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,11 @@ TEST_F(GemmaBatchBench, RandomQuestionsBatched) {

int main(int argc, char** argv) {
fprintf(stderr, "GemmaEnv setup..\n");
gcpp::GemmaEnv env(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();

gcpp::GemmaEnv env(args);
gcpp::s_env = &env;

testing::InitGoogleTest(&argc, argv);
Expand Down
7 changes: 5 additions & 2 deletions evals/gemma_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#include "evals/benchmark_helper.h"
#include "gemma/configs.h"
#include "io/io.h"
#include "hwy/base.h"
#include "hwy/tests/hwy_gtest.h"

Expand All @@ -42,7 +41,11 @@ class GemmaTest : public ::testing::Test {
// Requires argc/argv, hence do not use `SetUpTestSuite`.
static void InitEnv(int argc, char** argv) {
HWY_ASSERT(s_env == nullptr); // Should only be called once.
s_env = new GemmaEnv(argc, argv);
ConsumedArgs consumed(argc, argv);
GemmaArgs args(argc, argv, consumed);
consumed.AbortIfUnconsumed();

s_env = new GemmaEnv(args);
const gcpp::ModelConfig& config = s_env->GetGemma()->Config();
fprintf(stderr, "Using %s\n", config.Specifier().c_str());
}
Expand Down
16 changes: 11 additions & 5 deletions evals/run_mmlu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
namespace gcpp {

struct JsonArgs : public ArgsBase<JsonArgs> {
JsonArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
JsonArgs(int argc, char* argv[], ConsumedArgs& consumed) {
InitAndParse(argc, argv, consumed);
}

Path input;

Expand Down Expand Up @@ -151,10 +153,14 @@ void Run(GemmaEnv& env, JsonArgs& json) {
int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.all");
gcpp::GemmaEnv env(argc, argv);
gcpp::JsonArgs json(argc, argv);
gcpp::AbortIfInvalidArgs(json);
gcpp::Run(env, json);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
gcpp::JsonArgs json_args(argc, argv, consumed);
gcpp::AbortIfInvalidArgs(json_args);
consumed.AbortIfUnconsumed();

gcpp::GemmaEnv env(args);
gcpp::Run(env, json_args);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;
Expand Down
16 changes: 8 additions & 8 deletions examples/hello_world/run.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,20 @@
#include <vector>

#include "gemma/gemma.h"
#include "gemma/gemma_args.h" // LoaderArgs
#include "gemma/gemma_args.h" // GemmaArgs
#include "gemma/tokenizer.h"
#include "util/args.h"
#include "util/threading_context.h"
#include "hwy/base.h"

int main(int argc, char** argv) {
gcpp::LoaderArgs loader(argc, argv);
gcpp::ThreadingArgs threading(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::ConsumedArgs consumed(argc, argv);
gcpp::GemmaArgs args(argc, argv, consumed);
if (gcpp::HasHelp(argc, argv)) {
loader.Help();
args.Help();
return 0;
}
consumed.AbortIfUnconsumed();

// Demonstrate constrained decoding by never outputting certain tokens.
std::set<int> reject_tokens;
Expand All @@ -49,10 +49,10 @@ int main(int argc, char** argv) {
}

// Instantiate model and KV Cache
gcpp::ThreadingContext ctx(threading);
gcpp::ThreadingContext ctx(args.threading);
gcpp::MatMulEnv env(ctx);
gcpp::Gemma gemma(loader, inference, ctx);
gcpp::KVCache kv_cache(gemma.Config(), inference, ctx.allocator);
gcpp::Gemma gemma(args, ctx);
gcpp::KVCache kv_cache(gemma.Config(), args.inference, ctx.allocator);
size_t generated = 0;

// Tokenize instructions.
Expand Down
Loading
Loading