From d7c1324bdd6ac7dcafd015d7883d65587d2862eb Mon Sep 17 00:00:00 2001 From: stevenfontanella Date: Thu, 5 Feb 2026 21:49:25 +0000 Subject: [PATCH] Add support for either in wast --- scripts/test/shared.py | 11 +-- src/parser/wast-parser.cpp | 24 ++++- src/parser/wat-parser.h | 8 +- src/support/result.h | 20 ++-- src/tools/wasm-shell.cpp | 183 +++++++++++++++++++++++++------------ src/tools/wasm2js.cpp | 12 ++- 6 files changed, 179 insertions(+), 79 deletions(-) diff --git a/scripts/test/shared.py b/scripts/test/shared.py index 32516459283..41771e8a4c9 100644 --- a/scripts/test/shared.py +++ b/scripts/test/shared.py @@ -395,7 +395,7 @@ def get_tests(test_dir, extensions=[], recursive=False): # Test invalid 'elem.wast', - # Requires wast `either` support + # Requires scoping of `register` statements within `thread` blocks 'threads/thread.wast', # Requires better support for multi-threaded tests @@ -453,12 +453,9 @@ def get_tests(test_dir, extensions=[], recursive=False): 'type-subtyping.wast', # ShellExternalInterface::callTable does not handle subtyping 'memory64.wast', # Requires validations on the max memory size 'imports3.wast', # Requires better checking of exports from the special "spectest" module - 'i16x8_relaxed_q15mulr_s.wast', # Requires wast `either` support - 'i8x16_relaxed_swizzle.wast', # Requires wast `either` support - 'relaxed_dot_product.wast', # Requires wast `either` support - 'relaxed_laneselect.wast', # Requires wast `either` support - 'relaxed_madd_nmadd.wast', # Requires wast `either` support - 'relaxed_min_max.wast', # Requires wast `either` support + 'relaxed_dot_product.wast', # i16x8.relaxed_dot_i8x16_i7x16_s instruction not supported + 'relaxed_laneselect.wast', # i8x16.relaxed_laneselect instruction not supported + 'relaxed_min_max.wast', # Non-canonical NaN from f32x4.relaxed_min 'simd_const.wast', # Hex float constant not recognized as out of range 'simd_conversions.wast', # Promoted NaN should be canonical 'simd_f32x4.wast', # Min of 0 and NaN should give a canonical NaN diff --git a/src/parser/wast-parser.cpp b/src/parser/wast-parser.cpp index a7b0cba870a..5044b36f23a 100644 --- a/src/parser/wast-parser.cpp +++ b/src/parser/wast-parser.cpp @@ -288,12 +288,30 @@ Result result(Lexer& in) { return in.err("unrecognized result"); } +Result eitherResult(Lexer& in) { + if (in.takeSExprStart("either"sv)) { + ResultAlternatives alternatives; + do { + auto r = result(in); + CHECK_ERR(r); + + alternatives.push_back(*std::move(r)); + } while (!in.takeRParen()); + + return alternatives; + } + + auto r = result(in); + CHECK_ERR(r); + return ResultAlternatives{*std::move(r)}; +} + Result results(Lexer& in) { ExpectedResults res; while (!in.peekRParen()) { - auto r = result(in); + auto r = eitherResult(in); CHECK_ERR(r); - res.emplace_back(std::move(*r)); + res.emplace_back(*std::move(r)); } return res; } @@ -612,7 +630,7 @@ Result wast(Lexer& in) { return cmds; } CHECK_ERR(cmd); - cmds.push_back(ScriptEntry{std::move(*cmd), line}); + cmds.push_back(ScriptEntry{*std::move(cmd), line}); } return cmds; } diff --git a/src/parser/wat-parser.h b/src/parser/wat-parser.h index 8190db4e100..380d0f0be2c 100644 --- a/src/parser/wat-parser.h +++ b/src/parser/wat-parser.h @@ -79,7 +79,13 @@ using LaneResults = std::vector; using ExpectedResult = std::variant; -using ExpectedResults = std::vector; +using ResultAlternatives = std::vector; + +// The WAST spec states that `either`s maybe be nested arbitrarily e.g. +// (either (either "a" "b") (either "a" "c")) +// but we store this flattened since there's no way to tell the difference +// anyway. +using ExpectedResults = std::vector; struct AssertReturn { Action action; diff --git a/src/support/result.h b/src/support/result.h index 7cd360d533a..f34f6300759 100644 --- a/src/support/result.h +++ b/src/support/result.h @@ -36,21 +36,22 @@ struct Err { // Check a Result or MaybeResult for error and return the error if it exists. #define CHECK_ERR(val) \ if (auto _val = (val); auto err = _val.getErr()) { \ - return Err{*err}; \ + return (typename decltype(_val)::ErrorType)(*err); \ } // Represent a result of type T or an error message. -template struct [[nodiscard]] Result { - std::variant val; - - Result(Result& other) = default; - Result(Result&& other) = default; - Result(const Err& e) : val(std::in_place_type, e) {} - Result(Err&& e) : val(std::in_place_type, std::move(e)) {} +template struct [[nodiscard]] Result { + using ErrorType = E; + std::variant val; + + Result(Result& other) = default; + Result(Result&& other) = default; + Result(const E& e) : val(std::in_place_type, e) {} + Result(E&& e) : val(std::in_place_type, std::move(e)) {} template Result(U&& u) : val(std::in_place_type, std::forward(u)) {} - Err* getErr() { return std::get_if(&val); } + E* getErr() { return std::get_if(&val); } T& operator*() { return *std::get_if(&val); } T* operator->() { return std::get_if(&val); } }; @@ -58,6 +59,7 @@ template struct [[nodiscard]] Result { // Represent an optional result of type T or an error message. template struct [[nodiscard]] MaybeResult { std::variant val; + using ErrorType = Err; MaybeResult() : val(None{}) {} MaybeResult(MaybeResult& other) = default; diff --git a/src/tools/wasm-shell.cpp b/src/tools/wasm-shell.cpp index 32c1b98ad4e..849bc186462 100644 --- a/src/tools/wasm-shell.cpp +++ b/src/tools/wasm-shell.cpp @@ -319,13 +319,13 @@ struct Shell { switch (nan.kind) { case NaNKind::Canonical: if (val.type != nan.type || !val.isCanonicalNaN()) { - err << "expected canonical " << nan.type << " NaN, got " << val; + err << "canonical " << nan.type; return Err{err.str()}; } break; case NaNKind::Arithmetic: if (val.type != nan.type || !val.isArithmeticNaN()) { - err << "expected arithmetic " << nan.type << " NaN, got " << val; + err << "arithmetic " << nan.type; return Err{err.str()}; } break; @@ -333,17 +333,17 @@ struct Shell { return Ok{}; } - Result<> checkLane(Literal val, LaneResult expected, Index index) { + Result<> checkLane(Literal val, LaneResult expected) { std::stringstream err; if (auto* e = std::get_if(&expected)) { if (*e != val) { - err << "expected " << *e << ", got " << val << " at lane " << index; + err << *e; return Err{err.str()}; } } else if (auto* nan = std::get_if(&expected)) { auto check = checkNaN(val, *nan); if (auto* e = check.getErr()) { - err << e->msg << " at lane " << index; + err << e->msg; return Err{err.str()}; } } else { @@ -352,6 +352,83 @@ struct Shell { return Ok{}; } + struct AlternativeErr { + std::string expected; + int lane = -1; + }; + + Result matchAlternative(const Literal& val, + const ExpectedResult& expected, + bool isAlternative) { + std::stringstream err; + + if (auto* v = std::get_if(&expected)) { + if (val != *v) { + if (val.type.isVector() && v->type.isVector() && isAlternative) { + auto valLanes = val.getLanesI32x4(); + auto expLanes = v->getLanesI32x4(); + for (int i = 0; i < 4; ++i) { + if (valLanes[i] != expLanes[i]) { + err << "0x" << std::setfill('0') << std::setw(8) << std::hex + << expLanes[i] << std::dec; + return AlternativeErr{err.str(), i}; + } + } + } + err << *v; + return AlternativeErr{err.str()}; + } + } else if (auto* ref = std::get_if(&expected)) { + if (!val.type.isRef() || + !HeapType::isSubType(val.type.getHeapType(), ref->type)) { + err << ref->type; + return AlternativeErr{err.str()}; + } + } else if ([[maybe_unused]] auto* nullRef = + std::get_if(&expected)) { + if (!val.isNull()) { + err << "ref.null"; + return AlternativeErr{err.str()}; + } + } else if (auto* nan = std::get_if(&expected)) { + auto check = checkNaN(val, *nan); + if (auto* e = check.getErr()) { + err << e->msg; + return AlternativeErr{err.str()}; + } + } else if (auto* lanes = std::get_if(&expected)) { + switch (lanes->size()) { + case 4: { + auto vals = val.getLanesF32x4(); + for (int i = 0; i < 4; ++i) { + auto check = checkLane(vals[i], (*lanes)[i]); + if (auto* e = check.getErr()) { + err << e->msg; + return AlternativeErr{err.str(), i}; + } + } + break; + } + case 2: { + auto vals = val.getLanesF64x2(); + for (int i = 0; i < 2; ++i) { + auto check = checkLane(vals[i], (*lanes)[i]); + if (auto* e = check.getErr()) { + err << e->msg; + return AlternativeErr{err.str(), i}; + } + } + break; + } + default: + WASM_UNREACHABLE("unexpected number of lanes"); + } + } else { + WASM_UNREACHABLE("unexpected expectation"); + } + return Ok{}; + } + Result<> assertReturn(AssertReturn& assn) { std::stringstream err; auto result = doAction(assn.action); @@ -374,63 +451,55 @@ struct Shell { return ss.str(); }; - Literal val = (*values)[i]; - auto& expected = assn.expected[i]; - if (auto* v = std::get_if(&expected)) { - if (val != *v) { - err << "expected " << *v << ", got " << val << atIndex(); - return Err{err.str()}; - } - } else if (auto* ref = std::get_if(&expected)) { - if (!val.type.isRef() || - !HeapType::isSubType(val.type.getHeapType(), ref->type)) { - err << "expected " << ref->type << " reference, got " << val - << atIndex(); - return Err{err.str()}; - } - } else if ([[maybe_unused]] auto* nullRef = - std::get_if(&expected)) { - if (!val.isNull()) { - err << "expected ref.null, got " << val << atIndex(); - return Err{err.str()}; + // non-either case + if (assn.expected[i].size() == 1) { + auto result = matchAlternative( + (*values)[i], assn.expected[i][0], /*isAlternative=*/false); + if (auto* e = result.getErr()) { + std::stringstream ss; + ss << "expected " << e->expected << ", got " << (*values)[i]; + if (e->lane != -1) { + ss << " at lane " << e->lane; + } + ss << atIndex(); + return Err{ss.str()}; } - } else if (auto* nan = std::get_if(&expected)) { - auto check = checkNaN(val, *nan); - if (auto* e = check.getErr()) { - err << e->msg << atIndex(); - return Err{err.str()}; + continue; + } + + // either case + bool success = false; + std::vector expecteds; + int failedLane = -1; + for (const auto& alternative : assn.expected[i]) { + auto result = + matchAlternative((*values)[i], alternative, /*isAlternative=*/true); + if (!result.getErr()) { + success = true; + break; } - } else if (auto* lanes = std::get_if(&expected)) { - switch (lanes->size()) { - case 4: { - auto vals = val.getLanesF32x4(); - for (Index i = 0; i < 4; ++i) { - auto check = checkLane(vals[i], (*lanes)[i], i); - if (auto* e = check.getErr()) { - err << e->msg << atIndex(); - return Err{err.str()}; - } - } - break; - } - case 2: { - auto vals = val.getLanesF64x2(); - for (Index i = 0; i < 2; ++i) { - auto check = checkLane(vals[i], (*lanes)[i], i); - if (auto* e = check.getErr()) { - err << e->msg << atIndex(); - return Err{err.str()}; - } - } - break; - } - default: - WASM_UNREACHABLE("unexpected number of lanes"); + + auto* e = result.getErr(); + expecteds.push_back(e->expected); + if (failedLane == -1 && e->lane != -1) { + failedLane = e->lane; } - } else { - WASM_UNREACHABLE("unexpected expectation"); } + if (success) { + continue; + } + std::stringstream ss; + ss << "Expected one of (" << String::join(expecteds, " | ") << ")"; + if (failedLane != -1) { + ss << " at lane " << failedLane; + } + ss << " but got " << (*values)[i]; + + ss << atIndex(); + + return Err{ss.str()}; } + return Ok{}; } diff --git a/src/tools/wasm2js.cpp b/src/tools/wasm2js.cpp index 4821070103a..03be0f7132a 100644 --- a/src/tools/wasm2js.cpp +++ b/src/tools/wasm2js.cpp @@ -605,7 +605,15 @@ Ref AssertionEmitter::emitAssertReturnFunc(AssertReturn& assn, Name asmModule) { if (assn.expected.size() > 1) { Fatal() << "multivalue assert_return not supported"; + return {}; } + for (const auto& alternatives : assn.expected) { + if (alternatives.size() > 1) { + Fatal() << "(either ...) not supported"; + return {}; + } + } + auto* invoke = std::get_if(&assn.action); if (!invoke) { Fatal() << "only invoke actions are supported in assert_return"; @@ -619,7 +627,7 @@ Ref AssertionEmitter::emitAssertReturnFunc(AssertReturn& assn, } else { body = actual; } - } else if (auto* expectedVal = std::get_if(&assn.expected[0])) { + } else if (auto* expectedVal = std::get_if(&assn.expected[0][0])) { if (!expectedVal->type.isBasic()) { Fatal() << "unsupported type in assert_return: " << expectedVal->type; } @@ -648,7 +656,7 @@ Ref AssertionEmitter::emitAssertReturnFunc(AssertReturn& assn, Fatal() << "Unhandled type in assert: " << expected->type; } } - } else if (std::get_if(&assn.expected[0])) { + } else if (std::get_if(&assn.expected[0][0])) { body = builder.makeCall("isNaN", {actual}, Type::i32); } std::unique_ptr testFunc(