From f93779048c1f07843b31225c65c6a9aa9ae38fb2 Mon Sep 17 00:00:00 2001
From: starkwj
Date: Wed, 31 Jul 2024 13:52:46 +0800
Subject: [PATCH] Fix overflow in softmax

---
 lleaves/compiler/codegen/codegen.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/lleaves/compiler/codegen/codegen.py b/lleaves/compiler/codegen/codegen.py
index bf2e0a3..af17a4b 100644
--- a/lleaves/compiler/codegen/codegen.py
+++ b/lleaves/compiler/codegen/codegen.py
@@ -317,6 +317,7 @@ def _populate_objective_func_block(
     llvm_copysign = builder.module.declare_intrinsic(
         "llvm.copysign", (DTYPE, DTYPE), ir.FunctionType(DTYPE, (DTYPE, DTYPE))
     )
+    llvm_maxnum = builder.module.declare_intrinsic("llvm.maxnum", (DTYPE, DTYPE), ir.FunctionType(DTYPE, (DTYPE, DTYPE)))
 
     if average_output:
         args[0] = builder.fdiv(args[0], get_fdtype_const(num_trees, use_fp64))
@@ -368,7 +369,10 @@ def _populate_sigmoid(alpha):
     elif objective == "multiclass":
         assert len(args)
         # TODO Might profit from vectorization, needs testing
-        result = [builder.call(llvm_exp, [arg]) for arg in args]
+        arg_max = args[0]
+        for a in args[1:]:
+            arg_max = builder.call(llvm_maxnum, [arg_max, a])
+        result = [builder.call(llvm_exp, [builder.fsub(arg, arg_max)]) for arg in args]
 
         denominator = get_fdtype_const(0.0, use_fp64)
         for r in result: