diff --git a/kiwi.lua b/kiwi.lua index 31638e2..7c03db2 100644 --- a/kiwi.lua +++ b/kiwi.lua @@ -96,6 +96,8 @@ struct KiwiVar { void kiwi_var_free(KiwiVar* var); +void kiwi_expression_add_term(const KiwiExpression* expr, KiwiVar* var, double coeff, KiwiExpression* out); +void kiwi_expression_set_constant(const KiwiExpression* expr, double constant, KiwiExpression* out); ]]) else ffi.cdef([[ @@ -296,20 +298,29 @@ else end end ----@param expr kiwi.Expression ----@param var kiwi.Var ----@param coeff number? ----@nodiscard -local function add_expr_term(expr, var, coeff) - local ret = ffi_new(Expression, expr.term_count + 1) --[[@as kiwi.Expression]] - ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count) - local dt = ret.terms_[expr.term_count] - dt.var = var - dt.coefficient = coeff or 1.0 - ret.constant = expr.constant - ret.term_count = expr.term_count + 1 - ljkiwi.kiwi_expression_retain(ret) - return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] +local add_expr_term +if RUST then + ---@param expr kiwi.Expression + ---@param var kiwi.Var + ---@param coeff number? + ---@nodiscard + function add_expr_term(expr, var, coeff) + local ret = ffi_new(Expression, expr.term_count + 1) + ljkiwi.kiwi_expression_add_term(expr, var, coeff or 1.0, ret) + return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + end +else + function add_expr_term(expr, var, coeff) + local ret = ffi_new(Expression, expr.term_count + 1) --[[@as kiwi.Expression]] + ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count) + local dt = ret.terms_[expr.term_count] + dt.var = var + dt.coefficient = coeff or 1.0 + ret.constant = expr.constant + ret.term_count = expr.term_count + 1 + ljkiwi.kiwi_expression_retain(ret) + return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + end end ---@param constant number @@ -719,34 +730,33 @@ do local b_count = b.term_count local ret = ffi_new(Expression, a_count + b_count) --[[@as kiwi.Expression]] - for i = 0, a_count - 1 do - local dt = ret.terms_[i] --[[@as kiwi.Term]] - local st = a.terms_[i] --[[@as kiwi.Term]] - dt.var = st.var - dt.coefficient = st.coefficient - end - for i = 0, b_count - 1 do - local dt = ret.terms_[a_count + i] --[[@as kiwi.Term]] - local st = b.terms_[i] --[[@as kiwi.Term]] - dt.var = st.var - dt.coefficient = st.coefficient - end + ffi_copy(ret.terms_, a.terms_, SIZEOF_TERM * a_count) + ffi_copy(ret.terms_ + a_count, b.terms_, SIZEOF_TERM * b_count) ret.constant = a.constant + b.constant ret.term_count = a_count + b_count ljkiwi.kiwi_expression_retain(ret) return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] end - ---@param expr kiwi.Expression - ---@param constant number - ---@nodiscard - local function new_expr_constant(expr, constant) - local ret = ffi_new(Expression, expr.term_count) --[[@as kiwi.Expression]] - ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count) - ret.constant = constant - ret.term_count = expr.term_count - ljkiwi.kiwi_expression_retain(ret) - return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + local new_expr_constant + if RUST then + ---@param expr kiwi.Expression + ---@param constant number + ---@nodiscard + function new_expr_constant(expr, constant) + local ret = ffi_new(Expression, expr.term_count) + ljkiwi.kiwi_expression_set_constant(expr, constant, ret) + return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + end + else + function new_expr_constant(expr, constant) + local ret = ffi_new(Expression, expr.term_count) --[[@as kiwi.Expression]] + ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count) + ret.constant = constant + ret.term_count = expr.term_count + ljkiwi.kiwi_expression_retain(ret) + return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]] + end end ---@return number diff --git a/rjkiwi/src/expr.rs b/rjkiwi/src/expr.rs index 440ae75..839b524 100644 --- a/rjkiwi/src/expr.rs +++ b/rjkiwi/src/expr.rs @@ -48,6 +48,61 @@ impl KiwiExpression { } } +#[no_mangle] +pub unsafe extern "C" fn kiwi_expression_add_term( + e: *const KiwiExpressionPtr, + v: *const KiwiVar, + coefficient: c_double, + out: *mut KiwiExpressionPtr, +) { + let Some(e) = KiwiExpression::try_from_raw(e) else { + return; + }; + + let n_terms = (e.terms_.len() + 1).min(c_int::MAX as usize); + + let out = core::slice::from_raw_parts_mut(out as *mut (), n_terms) as *mut [()] + as *mut KiwiExpression; + + (*out).owner = out as *mut c_void; + (*out).term_count = n_terms as c_int; + (*out).constant = e.constant; + + for (o, i) in (*out).terms_.iter_mut().zip(e.terms_.iter()) { + o.var = KiwiVar::retain_raw(i.var); + o.coefficient = i.coefficient; + } + (*out).terms_[n_terms - 1] = KiwiTerm { + var: KiwiVar::retain_raw(v), + coefficient, + }; +} + +#[no_mangle] +pub unsafe extern "C" fn kiwi_expression_set_constant( + e: *const KiwiExpressionPtr, + constant: c_double, + out: *mut KiwiExpressionPtr, +) { + let Some(e) = KiwiExpression::try_from_raw(e) else { + return; + }; + + let n_terms = e.terms_.len(); + + let out = core::slice::from_raw_parts_mut(out as *mut (), n_terms) as *mut [()] + as *mut KiwiExpression; + + (*out).owner = out as *mut c_void; + (*out).term_count = n_terms as c_int; + (*out).constant = constant; + + for (o, i) in (*out).terms_.iter_mut().zip(e.terms_.iter()) { + o.var = KiwiVar::retain_raw(i.var); + o.coefficient = i.coefficient; + } +} + #[no_mangle] pub unsafe extern "C" fn kiwi_expression_retain(e: *mut KiwiExpressionPtr) { let Some(e) = not_null_mut(e) else {