Fix overflow in softmax #82

Closed

@starkwj starkwj commented Jul 31, 2024

In "multiclass" prediction, lleaves generates NaN for some test cases.
This is caused by the current softmax implementation, which can produce Inf in exp:

result = [builder.call(llvm_exp, [arg]) for arg in args]

This PR fixes the overflow by following the softmax in LightGBM, where the max of the args is subtracted from each arg before exp is applied: https://github.com/microsoft/LightGBM/blob/2f60e115a899a09d91e3710fce085444d3c740a3/tests/python_package_test/utils.py#L136-L139
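
For reference, the failure mode and the fix can be reproduced outside LLVM. The following is a minimal pure-Python sketch, not the lleaves code: `ieee_exp`, `naive_softmax`, and `stable_softmax` are hypothetical helpers, and `ieee_exp` is an assumption that mimics an IEEE-style exp by returning Inf on overflow (Python's `math.exp` raises instead).

```python
import math


def ieee_exp(x):
    # Hypothetical helper: overflow to +inf like an IEEE-754 exp,
    # instead of raising OverflowError as math.exp does.
    try:
        return math.exp(x)
    except OverflowError:
        return math.inf


def naive_softmax(scores):
    # exp is applied to the raw scores, so large scores overflow to inf.
    exps = [ieee_exp(s) for s in scores]
    denom = sum(exps)
    return [e / denom for e in exps]  # inf / inf evaluates to NaN


def stable_softmax(scores):
    # Subtracting the max keeps every exp argument <= 0, so exp stays in (0, 1].
    max_score = max(scores)
    exps = [ieee_exp(s - max_score) for s in scores]
    denom = sum(exps)
    return [e / denom for e in exps]


scores = [1000.0, 999.0, 0.0]
naive = naive_softmax(scores)    # contains NaN entries
stable = stable_softmax(scores)  # finite probabilities that sum to 1
```

Subtracting a constant from every score does not change the softmax result mathematically, because the factor exp(-max) cancels between numerator and denominator.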

@@ -317,6 +317,7 @@ def _populate_objective_func_block(
     llvm_copysign = builder.module.declare_intrinsic(
         "llvm.copysign", (DTYPE, DTYPE), ir.FunctionType(DTYPE, (DTYPE, DTYPE))
     )
+    llvm_maxnum = builder.module.declare_intrinsic("llvm.maxnum", (DTYPE, DTYPE), ir.FunctionType(DTYPE, (DTYPE, DTYPE)))
Owner

Any reason why you're using maxnum instead of maximum? Sounds like maximum propagates NaNs while maxnum doesn't, and when in doubt I'd probably rather propagate them.
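
To make the difference concrete, here is a small pure-Python model of the NaN handling of the two intrinsics. It is only a sketch of the documented semantics, not LLVM code: the function names are illustrative and the treatment of signed zeros is ignored.

```python
import math


def maxnum(a, b):
    # Models llvm.maxnum: if exactly one operand is NaN, the other is returned.
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return max(a, b)


def maximum(a, b):
    # Models llvm.maximum: any NaN operand makes the result NaN.
    if math.isnan(a) or math.isnan(b):
        return math.nan
    return max(a, b)


swallowed = maxnum(math.nan, 2.0)     # NaN is dropped, result is 2.0
propagated = maximum(math.nan, 2.0)   # NaN is kept
```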

Author

@starkwj starkwj Aug 1, 2024

Thanks, I agree with you.
However, maximum does not work for me (and this may be a common case). When I try maximum, it fails with "LLVM ERROR: Cannot select: t65: f64 = fmaximum t55, t64".
I tried to find the cause; my findings are listed below. I'm not familiar with LLVM, so I can't be sure they are correct.

  1. The issue is similar to "Cannot select llvm.{min,max}imum.{f32,f64}" (llvm/llvm-project#53353). From the replies, it seems that maximum for floats was only added recently: llvm/llvm-project@a82d27a
  2. For llvmlite, I installed the latest version, 0.43.0, as a pre-built binary (pip or conda), which bundles the required LLVM (https://llvmlite.readthedocs.io/en/latest/admin-guide/install.html). I believe this is a common setup. The bundled LLVM doesn't seem to support the float version of maximum.

Owner

Huh yeah it seems this is not implemented. I guess it doesn't super matter: since we're currently doing this elementwise anyway (instead of e.g. using llvm.vector.reduce.fmax), you could just implement it via fcmp followed by select.

Owner

siboehm commented Jul 31, 2024

Good catch! I'm confused why the implementation was so shitty in the first place

Owner

siboehm commented Aug 3, 2024

I fixed this in #83 in a slightly simpler way. Do you want to test if it works for you before I make a new release? Thanks a lot for finding the bug and pointing it out!

Author

starkwj commented Aug 4, 2024

> I fixed this in #83 in a slightly simpler way. Do you want to test if it works for you before I make a new release? Thanks a lot for finding the bug and pointing it out!

Thanks a lot!
But I have a doubt:

max_val = args[0]
for arg in args[1:]:
    max_val = builder.select(
        builder.fcmp_ordered(">", arg, max_val), arg, max_val
    )

If args[0] is not a NaN, but args[1] is NaN, fcmp_ordered returns false and select chooses args[0], so NaN is still not propagated.

However, I just noticed that the following statements of exp and fsub can still propagate NaN in args:
result = [builder.call(llvm_exp, [builder.fsub(arg, arg_max)]) for arg in args].
Thus, maybe it's not necessary to handle NaN when selecting max?

Owner

siboehm commented Aug 4, 2024

Yes I think you're right. I don't think there's a way to get NaNs from the max without doing an explicit NaN check. But since the denom will be NaN if any of the values is NaN this won't really matter
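
A quick way to check this reasoning is to model it in plain Python. This is only a sketch under the assumption that Python's `>` behaves like an ordered comparison (it is False whenever an operand is NaN); `select_max` and `softmax` are illustrative names, not the lleaves functions.

```python
import math


def select_max(args):
    # Mirrors the fcmp_ordered(">") + select loop: the comparison is False
    # when either operand is NaN, so the current max_val is kept.
    max_val = args[0]
    for arg in args[1:]:
        max_val = arg if arg > max_val else max_val
    return max_val


def softmax(args):
    max_val = select_max(args)
    # NaN - max_val is NaN and exp(NaN) is NaN, so a NaN input reaches the sum.
    exps = [math.exp(arg - max_val) for arg in args]
    denom = sum(exps)
    return [e / denom for e in exps]


args = [1.0, math.nan, 3.0]
max_val = select_max(args)  # the NaN is skipped by the max selection
probs = softmax(args)       # every entry is NaN because the denominator is NaN
```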

Owner

siboehm commented Aug 4, 2024

Also this has gone out as v1.2.3 on PyPI, conda-forge should be releasing later today

@siboehm siboehm closed this Aug 6, 2024