Add support to non-fp32 input dtype in level3/33_VanillaRNN.py and level3/35_LSTM.py#131
Open
mzweilin wants to merge 1 commit intoScalingIntelligence:mainfrom
Open
Add support to non-fp32 input dtype in level3/33_VanillaRNN.py and level3/35_LSTM.py#131mzweilin wants to merge 1 commit intoScalingIntelligence:mainfrom
level3/33_VanillaRNN.py and level3/35_LSTM.py#131mzweilin wants to merge 1 commit intoScalingIntelligence:mainfrom
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR fixes a bug in two problems
level3/33_VanillaRNN.pyandlevel3/35_LSTM.pythatModeluses hard-coded FP32 tensors.It is related to #79 and #80. You will find the bug if you try to run the agent on the two problems.
Please consider updating the dataset in HuggingFace too: https://huggingface.co/datasets/ScalingIntelligence/KernelBench
An alternative fix could be moving the random tensors from
Modeltoget_inputs().How to reproduce the bug
$ python -i KernelBench/level3/33_VanillaRNN.pyRuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16$ python -i KernelBench/level3/35_LSTM.pyRuntimeError: could not create a primitive descriptor for the LSTM forward propagation primitive. Run workload with environment variable ONEDNN_VERBOSE=all to get additional diagnostic information.