Skip to content

Commit

Permalink
Add some expression processing functions to rust lib
Browse files Browse the repository at this point in the history
  • Loading branch information
jkl1337 committed Mar 9, 2024
1 parent 9b1c296 commit 15e578b
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 36 deletions.
82 changes: 46 additions & 36 deletions kiwi.lua
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ struct KiwiVar {
void kiwi_var_free(KiwiVar* var);
void kiwi_expression_add_term(const KiwiExpression* expr, KiwiVar* var, double coeff, KiwiExpression* out);
void kiwi_expression_set_constant(const KiwiExpression* expr, double constant, KiwiExpression* out);
]])
else
ffi.cdef([[
Expand Down Expand Up @@ -296,20 +298,29 @@ else
end
end

---@param expr kiwi.Expression
---@param var kiwi.Var
---@param coeff number?
---@nodiscard
local function add_expr_term(expr, var, coeff)
local ret = ffi_new(Expression, expr.term_count + 1) --[[@as kiwi.Expression]]
ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count)
local dt = ret.terms_[expr.term_count]
dt.var = var
dt.coefficient = coeff or 1.0
ret.constant = expr.constant
ret.term_count = expr.term_count + 1
ljkiwi.kiwi_expression_retain(ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
local add_expr_term
if RUST then
---@param expr kiwi.Expression
---@param var kiwi.Var
---@param coeff number?
---@nodiscard
function add_expr_term(expr, var, coeff)
local ret = ffi_new(Expression, expr.term_count + 1)
ljkiwi.kiwi_expression_add_term(expr, var, coeff or 1.0, ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end
else
function add_expr_term(expr, var, coeff)
local ret = ffi_new(Expression, expr.term_count + 1) --[[@as kiwi.Expression]]
ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count)
local dt = ret.terms_[expr.term_count]
dt.var = var
dt.coefficient = coeff or 1.0
ret.constant = expr.constant
ret.term_count = expr.term_count + 1
ljkiwi.kiwi_expression_retain(ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end
end

---@param constant number
Expand Down Expand Up @@ -719,34 +730,33 @@ do
local b_count = b.term_count
local ret = ffi_new(Expression, a_count + b_count) --[[@as kiwi.Expression]]

for i = 0, a_count - 1 do
local dt = ret.terms_[i] --[[@as kiwi.Term]]
local st = a.terms_[i] --[[@as kiwi.Term]]
dt.var = st.var
dt.coefficient = st.coefficient
end
for i = 0, b_count - 1 do
local dt = ret.terms_[a_count + i] --[[@as kiwi.Term]]
local st = b.terms_[i] --[[@as kiwi.Term]]
dt.var = st.var
dt.coefficient = st.coefficient
end
ffi_copy(ret.terms_, a.terms_, SIZEOF_TERM * a_count)
ffi_copy(ret.terms_ + a_count, b.terms_, SIZEOF_TERM * b_count)
ret.constant = a.constant + b.constant
ret.term_count = a_count + b_count
ljkiwi.kiwi_expression_retain(ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end

---@param expr kiwi.Expression
---@param constant number
---@nodiscard
local function new_expr_constant(expr, constant)
local ret = ffi_new(Expression, expr.term_count) --[[@as kiwi.Expression]]
ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count)
ret.constant = constant
ret.term_count = expr.term_count
ljkiwi.kiwi_expression_retain(ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
local new_expr_constant
if RUST then
---@param expr kiwi.Expression
---@param constant number
---@nodiscard
function new_expr_constant(expr, constant)
local ret = ffi_new(Expression, expr.term_count)
ljkiwi.kiwi_expression_set_constant(expr, constant, ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end
else
function new_expr_constant(expr, constant)
local ret = ffi_new(Expression, expr.term_count) --[[@as kiwi.Expression]]
ffi_copy(ret.terms_, expr.terms_, SIZEOF_TERM * expr.term_count)
ret.constant = constant
ret.term_count = expr.term_count
ljkiwi.kiwi_expression_retain(ret)
return ffi_gc(ret, ljkiwi.kiwi_expression_destroy) --[[@as kiwi.Expression]]
end
end

---@return number
Expand Down
55 changes: 55 additions & 0 deletions rjkiwi/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,61 @@ impl KiwiExpression {
}
}

#[no_mangle]
pub unsafe extern "C" fn kiwi_expression_add_term(
e: *const KiwiExpressionPtr,
v: *const KiwiVar,
coefficient: c_double,
out: *mut KiwiExpressionPtr,
) {
let Some(e) = KiwiExpression::try_from_raw(e) else {
return;
};

let n_terms = (e.terms_.len() + 1).min(c_int::MAX as usize);

let out = core::slice::from_raw_parts_mut(out as *mut (), n_terms) as *mut [()]
as *mut KiwiExpression;

(*out).owner = out as *mut c_void;
(*out).term_count = n_terms as c_int;
(*out).constant = e.constant;

for (o, i) in (*out).terms_.iter_mut().zip(e.terms_.iter()) {
o.var = KiwiVar::retain_raw(i.var);
o.coefficient = i.coefficient;
}
(*out).terms_[n_terms - 1] = KiwiTerm {
var: KiwiVar::retain_raw(v),
coefficient,
};
}

#[no_mangle]
pub unsafe extern "C" fn kiwi_expression_set_constant(
e: *const KiwiExpressionPtr,
constant: c_double,
out: *mut KiwiExpressionPtr,
) {
let Some(e) = KiwiExpression::try_from_raw(e) else {
return;
};

let n_terms = e.terms_.len();

let out = core::slice::from_raw_parts_mut(out as *mut (), n_terms) as *mut [()]
as *mut KiwiExpression;

(*out).owner = out as *mut c_void;
(*out).term_count = n_terms as c_int;
(*out).constant = constant;

for (o, i) in (*out).terms_.iter_mut().zip(e.terms_.iter()) {
o.var = KiwiVar::retain_raw(i.var);
o.coefficient = i.coefficient;
}
}

#[no_mangle]
pub unsafe extern "C" fn kiwi_expression_retain(e: *mut KiwiExpressionPtr) {
let Some(e) = not_null_mut(e) else {
Expand Down

0 comments on commit 15e578b

Please sign in to comment.