-
Notifications
You must be signed in to change notification settings - Fork 527
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] Add OnnxToTorch lowering for Onnx.NegativeLogLikelihoodLoss Op #3380
Conversation
Hi @123epsilon @ScottTodd @andfau-amd, do we need this PR for this op lowering or not? Otherwise, we should close this and get this op working through @andfau-amd's patch. |
It does look like my patch can expand this op, e.g.:
However, I haven't so far been able to test whether this expansion is actually correct/useful. FWIW, it does seem like TorchOnnxToTorch doesn't choke on it:
|
To me, the above expansion does not seem correct, since I don't see the |
Well, I don't see any mention of Log in the examples here: https://onnx.ai/onnx/operators/onnx__NegativeLogLikelihoodLoss.html It seems like the "log" part is not actually in this operation, but rather the inputs are meant to already be logarithmic? By the way, it seems like the imported ONNX node tests pass for this op: #3409 (comment) |
Sorry for the confusion, it's a fault at my end. Based on the op definition lowering looks fine. |
Maybe the expansion from this PR is better than the one in mine, since it uses the dedicated Torch NLL op? |
I think taking a conservative approach, we can keep the op lowering now, and can delete it later once the function expansion thing is in and stable? |
Sounds good to me. |
This implements the Onnx.NegativeLogLikelihoodLoss op using the signature provided here by replacing it with a
NLLLossForward
op.Additionally, I included a helper function
get_loss_reduction_enum
to convert from a stringreduction
parameter to the corresponding intended integer value since this is an operation that will be reused for any loss function module. This differs fromget_reduction_enum
inTorchUpstream.cpp
which handles thereduce
parameter fromscatter_reduce
type operations.