Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion check.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ def check_expected(actual, expected, stdout=None):

UNSPLITTABLE_TESTS = [Path(x) for x in [
"spec/testsuite/instance.wast",
"spec/instance.wast"]]
"spec/instance.wast",

# TODO: support module splitting for (thread ...) blocks
"spec/threads/*"]]


def is_splittable(wast: Path):
Expand Down
9 changes: 9 additions & 0 deletions scripts/test/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,15 @@ def get_tests(test_dir, extensions=[], recursive=False):

# Test invalid
'elem.wast',

# Requires wast `either` support
'threads/thread.wast',

# Requires better support for multi-threaded tests
'threads/wait_notify.wast',

# Non-natural alignment is invalid for atomic operations
'threads/atomic.wast',
]
SPEC_TESTSUITE_PROPOSALS_TO_SKIP = [
'custom-page-sizes',
Expand Down
69 changes: 69 additions & 0 deletions src/parser/wast-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ using namespace std::string_view_literals;

namespace {

Result<WASTCommand> command(Lexer& in);

Result<Literal> const_(Lexer& in) {
if (in.takeSExprStart("ref.extern"sv)) {
auto n = in.takeI32();
Expand Down Expand Up @@ -496,12 +498,79 @@ MaybeResult<ModuleInstantiation> instantiation(Lexer& in) {
return ModuleInstantiation{moduleId, instanceId};
}

// (thread name (shared (module name))? command*)
MaybeResult<ThreadBlock> thread_(Lexer& in) {
if (!in.takeSExprStart("thread"sv)) {
return {};
}

auto name = in.takeID();
if (!name) {
return in.err("expected thread name");
}

std::optional<Name> sharedModule;
if (in.takeSExprStart("shared"sv)) {
if (!in.takeSExprStart("module"sv)) {
return in.err("expected module keyword in (shared ...) block");
}

auto modName = in.takeID();
if (!modName) {
return in.err("expected module name after (shared (module ...))");
}
sharedModule = *modName;

if (!in.takeRParen()) {
return in.err("expected end of shared module");
}
if (!in.takeRParen()) {
return in.err("expected end of (shared ...) expression");
}
}

std::vector<ScriptEntry> commands;
while (!in.peekRParen() && !in.empty()) {
size_t line = in.position().line;
auto cmd = command(in);
CHECK_ERR(cmd);
commands.push_back({std::move(*cmd), line});
}
if (!in.takeRParen()) {
return in.err("expected end of thread");
}
return ThreadBlock{*name, sharedModule, std::move(commands)};
}

// (wait name)
MaybeResult<Wait> wait_(Lexer& in) {
if (!in.takeSExprStart("wait"sv)) {
return {};
}
auto name = in.takeID();
if (!name) {
return in.err("expected thread name in wait");
}
if (!in.takeRParen()) {
return in.err("expected end of wait");
}
return Wait{*name};
}

// instantiate | module | register | action | assertion
Result<WASTCommand> command(Lexer& in) {
if (auto cmd = register_(in)) {
CHECK_ERR(cmd);
return *cmd;
}
if (auto cmd = thread_(in)) {
CHECK_ERR(cmd);
return *cmd;
}
if (auto cmd = wait_(in)) {
CHECK_ERR(cmd);
return *cmd;
}
if (auto cmd = maybeAction(in)) {
CHECK_ERR(cmd);
return *cmd;
Expand Down
24 changes: 20 additions & 4 deletions src/parser/wat-parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,32 @@ struct ModuleInstantiation {
std::optional<Name> instanceName;
};

using WASTCommand =
std::variant<WASTModule, Register, Action, Assertion, ModuleInstantiation>;
struct ScriptEntry;
using WASTScript = std::vector<ScriptEntry>;

struct ThreadBlock {
Name name;
std::optional<Name> sharedModule;
WASTScript commands;
};

struct Wait {
Name thread;
};

using WASTCommand = std::variant<WASTModule,
Register,
Action,
Assertion,
ModuleInstantiation,
ThreadBlock,
Wait>;

struct ScriptEntry {
WASTCommand cmd;
size_t line;
};

using WASTScript = std::vector<ScriptEntry>;

Result<WASTScript> parseScript(std::string_view in);

} // namespace wasm::WATParser
Expand Down
143 changes: 142 additions & 1 deletion src/tools/wasm-shell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ struct Shell {

Options& options;

struct ThreadState {
Name name;
std::vector<WATParser::ScriptEntry> commands;
size_t pc = 0;
bool isSuspended = false;
std::shared_ptr<ModuleRunner> instance = nullptr;
std::shared_ptr<ContData> suspendedCont = nullptr;
bool done = false;
};
std::vector<ThreadState> activeThreads;

Shell(Options& options) : options(options) { buildSpectestModule(); }

Result<> run(WASTScript& script) {
Expand Down Expand Up @@ -93,11 +104,134 @@ struct Shell {
} else if (auto* instantiateModule =
std::get_if<ModuleInstantiation>(&cmd)) {
return doInstantiate(*instantiateModule);
} else if (auto* thread = std::get_if<ThreadBlock>(&cmd)) {
return doThread(*thread);
} else if (auto* wait = std::get_if<Wait>(&cmd)) {
return doWait(*wait);
} else {
WASM_UNREACHABLE("unexpected command");
}
}

// Run threads in a blocking manner for now.
// TODO: yield on blocking instructions e.g. memory.atomic.wait32.
Result<> doThread(ThreadBlock& thread) {
ThreadState state;
state.name = thread.name;
state.commands = thread.commands;
activeThreads.push_back(std::move(state));
return Ok{};
}

Result<> doWait(Wait& wait) {
bool found = false;
for (auto& t : activeThreads) {
if (t.name == wait.thread) {
found = true;
break;
}
}
if (!found) {
return Err{"wait called for unknown thread"};
}

// Round-robin execution
while (true) {
bool anyProgress = false;
bool targetDone = false;

for (auto& t : activeThreads) {
if (t.done) {
if (t.name == wait.thread)
targetDone = true;
continue;
}

if (t.isSuspended) {
// Check if it's still waiting. WaitQueue sets `isWaiting` to false
// when notified.
bool stillWaiting = t.suspendedCont && t.suspendedCont->isWaiting;

if (!stillWaiting) {
// It was woken up! We need to resume it.
t.isSuspended = false;
Flow flow;
try {
flow = t.instance->resumeContinuation(t.suspendedCont);
} catch (TrapException&) {
std::cerr << "Thread " << t.name << " trapped upon resume\n";
t.done = true;
anyProgress = true;
continue;
} catch (...) {
WASM_UNREACHABLE("unexpected error during resume");
}
t.suspendedCont = nullptr;

if (flow.breakTo == THREAD_SUSPEND_FLOW) {
// Suspended again
t.isSuspended = true;
t.suspendedCont = t.instance->getSuspendedContinuation();
anyProgress = true;
} else if (flow.suspendTag) {
t.instance->clearContinuationStore();
t.done = true; // unhandled suspension
anyProgress = true;
} else {
t.pc++; // Completed the command that originally suspended!
anyProgress = true;
}
}
} else {
// Normal execution of the next command.
if (t.pc < t.commands.size()) {
auto& cmd = t.commands[t.pc].cmd;
if (auto* act = std::get_if<Action>(&cmd)) {
auto result = doAction(*act);
if (std::get_if<ThreadSuspendResult>(&result)) {
t.isSuspended = true;
if (auto* invoke = std::get_if<InvokeAction>(act)) {
t.instance =
instances[invoke->base ? *invoke->base : lastInstance];
t.suspendedCont = t.instance->getSuspendedContinuation();
}
anyProgress = true;
} else {
t.pc++;
anyProgress = true;
}
} else {
// Not an action, just run it (e.g. module instantiation or
// assertions inside thread)
auto res = runCommand(cmd);
if (res.getErr()) {
std::cerr << "Thread " << t.name
<< " error: " << res.getErr()->msg << "\n";
t.done = true;
} else {
t.pc++;
anyProgress = true;
}
}
} else {
t.done = true;
anyProgress = true; // finishing counts as progress
}
}
}

if (targetDone) {
break;
}

if (!anyProgress) {
// Find if target is still suspended
return Err{"deadlock! no threads can make progress"};
}
}
return Ok{};
}

Result<std::shared_ptr<Module>> makeModule(WASTModule& mod) {
std::shared_ptr<Module> wasm;
if (auto* quoted = std::get_if<QuotedModule>(&mod.module)) {
Expand Down Expand Up @@ -223,11 +357,13 @@ struct Shell {
struct HostLimitResult {};
struct ExceptionResult {};
struct SuspensionResult {};
struct ThreadSuspendResult {};
using ActionResult = std::variant<Literals,
TrapResult,
HostLimitResult,
ExceptionResult,
SuspensionResult>;
SuspensionResult,
ThreadSuspendResult>;

std::string resultToString(ActionResult& result) {
if (std::get_if<TrapResult>(&result)) {
Expand All @@ -238,6 +374,8 @@ struct Shell {
return "exception";
} else if (std::get_if<SuspensionResult>(&result)) {
return "suspension";
} else if (std::get_if<ThreadSuspendResult>(&result)) {
return "thread_suspend";
} else if (auto* vals = std::get_if<Literals>(&result)) {
std::stringstream ss;
ss << *vals;
Expand Down Expand Up @@ -267,6 +405,9 @@ struct Shell {
} catch (...) {
WASM_UNREACHABLE("unexpected error");
}
if (flow.breakTo == THREAD_SUSPEND_FLOW) {
return ThreadSuspendResult{};
}
if (flow.suspendTag) {
// This is an unhandled suspension. Handle it here - clear the
// suspension state - so nothing else is affected.
Expand Down
Loading
Loading