Use feature to enable function extensions
mokeyish committed Dec 18, 2023
1 parent 6b73105 commit 675c790
Showing 27 changed files with 129 additions and 5 deletions.
41 changes: 40 additions & 1 deletion Cargo.toml
@@ -7,14 +7,53 @@ description = "An extension library to Candle that provides PyTorch functions not currently available in Candle"
license = "MIT OR Apache-2.0"
repository = "https://github.com/mokeyish/candle-ext"

[lib]
doctest = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]

default = []
default = ["all"]

all = [
"chunk",
"cumsum",
"equal",
"eye",
"logical_not",
"logical_or",
"masked_fill",
"outer",
"scaled_dot_product_attention",
"triangular",
"unbind",
]

# functions feature
chunk = ["to_tuple"]
cumsum = []
equal = []
eye = []
full = []
full_like = ["full"]
logical_not = []
logical_or = []
masked_fill = ["full_like"]
outer = []
scaled_dot_product_attention = ["masked_fill", "logical_not", "tril"]
triangular = []
to_tuple = []
tril = ["triangular"]
trilu = ["triangular"]
unbind = ["to_tuple"]

cuda = ["candle-core/cuda"]

# features for unit tests
test_masked_fill = ["masked_fill", "triangular", "logical_not"]


[dependencies]
candle-core = "0.3"
candle-nn = "0.3"
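The feature list above forms a small dependency graph: enabling `scaled_dot_product_attention`, for instance, transitively pulls in `masked_fill`, `logical_not`, and `tril`. A hypothetical sketch (function name, shapes, and dtype are assumptions, not from this commit) of code that relies only on that one feature being enabled:

```rust
// Hypothetical helper, not part of this commit. It only needs the
// `scaled_dot_product_attention` feature, because that feature transitively
// enables `tril`, `logical_not`, and `masked_fill`.
use candle_ext::{
    candle::{DType, Result, Tensor},
    TensorExt,
};

fn apply_causal_mask(scores: &Tensor) -> Result<Tensor> {
    // Assumed layout: (batch, heads, L, S) with f32 scores.
    let (_b, _h, l, s) = scores.dims4()?;
    // Lower-triangular "keep" mask; everything above the diagonal is masked out.
    let keep = Tensor::ones((l, s), DType::U8, scores.device())?.tril(0)?;
    // Assuming the PyTorch convention that `masked_fill` writes the value
    // where the mask is true, invert the keep-mask first.
    let drop = keep.logical_not()?.broadcast_as(scores.shape())?;
    scores.masked_fill(&drop, f32::NEG_INFINITY)
}
```

Since `default = ["all"]`, nothing changes for existing users; consumers who want a smaller build can set `default-features = false` and list only the functions they need.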
1 change: 1 addition & 0 deletions src/chunk.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "chunk")]
use crate::{
candle::{shape::Dim, Result, Tensor},
TensorVecExt, F,
1 change: 1 addition & 0 deletions src/cumsum.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "cumsum")]
use crate::{
candle::{shape::Dim, Result, Tensor},
F,
1 change: 1 addition & 0 deletions src/equal.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "equal")]
use crate::{
candle::{Result, Tensor},
F,
1 change: 1 addition & 0 deletions src/eye.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "eye")]
use crate::{
candle::{bail, DType, Device, Result, Shape, Tensor},
F,
3 changes: 2 additions & 1 deletion src/full.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "full")]
use crate::{
candle::{DType, Device, Result, Shape, Tensor, WithDType},
F,
@@ -24,7 +25,7 @@ impl F {
/// F::full(input.shape(), fill_value, dtype=input.dtype(), device=input.device()).
///
/// [https://pytorch.org/docs/stable/generated/torch.full_like.html](https://pytorch.org/docs/stable/generated/torch.full_like.html)
#[inline]
#[cfg(feature = "full_like")]
pub fn full_like<D: WithDType>(input: &Tensor, fill_value: D) -> Result<Tensor> {
F::full(input.shape(), fill_value, input.dtype(), input.device())
}
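`full_like` above simply forwards to `F::full` with the input's shape, dtype, and device, which is why `full_like = ["full"]` in Cargo.toml. A minimal usage sketch (assumed, not part of this commit; it also uses `equal`, which is on under the default `all` set):

```rust
// Usage sketch (assumed). `full_like` mirrors torch.full_like by delegating
// to `F::full` with the source tensor's shape, dtype, and device.
use candle_ext::{
    candle::{DType, Device, Result, Tensor},
    TensorExt, F,
};

fn full_like_demo() -> Result<()> {
    let device = Device::Cpu;
    let x = Tensor::zeros((2, 3), DType::F32, &device)?;
    let y = x.full_like(7.0f32)?; // same shape/dtype/device as `x`, filled with 7.0
    let z = F::full(x.shape(), 7.0f32, x.dtype(), x.device())?; // explicit equivalent
    assert!(y.equal(&z)?); // needs the `equal` feature
    Ok(())
}
```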
54 changes: 53 additions & 1 deletion src/lib.rs
@@ -1,3 +1,4 @@
#![allow(unused_imports)]
//! An extension library to [Candle](https://github.com/huggingface/candle) that provides PyTorch functions not currently available in Candle
//!
//! # Examples
@@ -68,123 +69,162 @@ mod unbind;
pub struct F;

pub trait TensorExt: Sized {
#[cfg(feature = "chunk")]
fn chunk2<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor)>;
#[cfg(feature = "chunk")]
fn chunk3<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor)>;
#[cfg(feature = "chunk")]
fn chunk4<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor, Tensor)>;
#[cfg(feature = "chunk")]
fn chunk5<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor)>;
#[cfg(feature = "cumsum")]
fn cumsum<D: Dim>(&self, dim: D) -> Result<Tensor>;
#[cfg(feature = "equal")]
fn equal(&self, other: &Tensor) -> Result<bool>;
#[cfg(feature = "eye")]
fn eye<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Tensor>;
#[cfg(feature = "full")]
fn full<S: Into<Shape>, D: WithDType>(
shape: S,
fill_value: D,
dtype: DType,
device: &Device,
) -> Result<Tensor>;
#[cfg(feature = "full_like")]
fn full_like<D: WithDType>(&self, fill_value: D) -> Result<Tensor>;
#[cfg(feature = "logical_not")]
fn logical_not(&self) -> Result<Self>;
#[cfg(feature = "logical_or")]
fn logical_or(&self, other: &Tensor) -> Result<Self>;
#[cfg(feature = "masked_fill")]
fn masked_fill<D: WithDType>(&self, mask: &Tensor, value: D) -> Result<Self>;
#[cfg(feature = "outer")]
fn outer(&self, vec2: &Tensor) -> Result<Self>;
#[cfg(feature = "triangular")]
fn tril(&self, diagonal: isize) -> Result<Self>;
#[cfg(feature = "triangular")]
fn triu(&self, diagonal: isize) -> Result<Self>;
#[cfg(feature = "unbind")]
fn unbind<D: Dim>(&self, dim: D) -> Result<Vec<Tensor>>;
#[cfg(feature = "unbind")]
fn unbind2<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor)>;
#[cfg(feature = "unbind")]
fn unbind3<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor)>;
#[cfg(feature = "unbind")]
fn unbind4<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor, Tensor)>;
#[cfg(feature = "unbind")]
fn unbind5<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor)>;
}

impl TensorExt for Tensor {
#[cfg(feature = "triangular")]
#[inline]
fn tril(&self, diagonal: isize) -> Result<Self> {
F::tril(self, diagonal)
}

#[cfg(feature = "triangular")]
#[inline]
fn triu(&self, diagonal: isize) -> Result<Self> {
F::triu(self, diagonal)
}

#[cfg(feature = "logical_not")]
#[inline]
fn logical_not(&self) -> Result<Self> {
F::logical_not(self)
}

#[cfg(feature = "logical_or")]
#[inline]
fn logical_or(&self, other: &Tensor) -> Result<Self> {
F::logical_or(self, other)
}

#[cfg(feature = "masked_fill")]
#[inline]
fn masked_fill<D: WithDType>(&self, mask: &Tensor, value: D) -> Result<Self> {
F::masked_fill(self, mask, value)
}

#[cfg(feature = "outer")]
#[inline]
fn outer(&self, vec2: &Tensor) -> Result<Self> {
F::outer(self, vec2)
}

#[cfg(feature = "unbind")]
#[inline]
fn unbind<D: Dim>(&self, dim: D) -> Result<Vec<Tensor>> {
F::unbind(self, dim)
}

#[cfg(feature = "unbind")]
#[inline]
fn unbind2<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor)> {
F::unbind2(self, dim)
}

#[cfg(feature = "unbind")]
#[inline]
fn unbind3<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor)> {
F::unbind3(self, dim)
}

#[cfg(feature = "unbind")]
#[inline]
fn unbind4<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor, Tensor)> {
F::unbind4(self, dim)
}

#[cfg(feature = "unbind")]
#[inline]
fn unbind5<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor)> {
F::unbind5(self, dim)
}

#[cfg(feature = "equal")]
#[inline]
fn equal(&self, other: &Tensor) -> Result<bool> {
F::equal(self, other)
}

#[cfg(feature = "eye")]
#[inline]
fn eye<S: Into<Shape>>(shape: S, dtype: DType, device: &Device) -> Result<Tensor> {
F::eye(shape, dtype, device)
}

#[cfg(feature = "chunk")]
#[inline]
fn chunk2<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor)> {
F::chunk2(self, dim)
}

#[cfg(feature = "chunk")]
#[inline]
fn chunk3<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor)> {
F::chunk3(self, dim)
}

#[cfg(feature = "chunk")]
#[inline]
fn chunk4<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor, Tensor)> {
F::chunk4(self, dim)
}

#[cfg(feature = "chunk")]
#[inline]
fn chunk5<D: Dim>(&self, dim: D) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor)> {
F::chunk5(self, dim)
}

#[cfg(feature = "cumsum")]
#[inline]
fn cumsum<D: Dim>(&self, dim: D) -> Result<Tensor> {
F::cumsum(self, dim)
}

#[cfg(feature = "full")]
#[inline]
fn full<S: Into<Shape>, D: WithDType>(
shape: S,
@@ -195,32 +235,44 @@ impl TensorExt for Tensor {
F::full(shape, fill_value, dtype, device)
}

#[cfg(feature = "full_like")]
#[inline]
fn full_like<D: WithDType>(&self, fill_value: D) -> Result<Tensor> {
F::full_like(self, fill_value)
}
}

pub trait TensorVecExt {
#[cfg(feature = "to_tuple")]
fn to_tuple2(self) -> Result<(Tensor, Tensor)>;
#[cfg(feature = "to_tuple")]
fn to_tuple3(self) -> Result<(Tensor, Tensor, Tensor)>;
#[cfg(feature = "to_tuple")]
fn to_tuple4(self) -> Result<(Tensor, Tensor, Tensor, Tensor)>;
#[cfg(feature = "to_tuple")]
fn to_tuple5(self) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor)>;
}

impl TensorVecExt for Vec<Tensor> {
#[cfg(feature = "to_tuple")]
#[inline]
fn to_tuple2(self) -> Result<(Tensor, Tensor)> {
F::to_tuple2(self)
}

#[cfg(feature = "to_tuple")]
#[inline]
fn to_tuple3(self) -> Result<(Tensor, Tensor, Tensor)> {
F::to_tuple3(self)
}

#[cfg(feature = "to_tuple")]
#[inline]
fn to_tuple4(self) -> Result<(Tensor, Tensor, Tensor, Tensor)> {
F::to_tuple4(self)
}

#[cfg(feature = "to_tuple")]
#[inline]
fn to_tuple5(self) -> Result<(Tensor, Tensor, Tensor, Tensor, Tensor)> {
F::to_tuple5(self)
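`TensorVecExt` above is the small helper trait behind the tuple-returning methods, which is why `chunk = ["to_tuple"]` and `unbind = ["to_tuple"]` in Cargo.toml. A short sketch of that relationship (assumed usage, not from this commit):

```rust
// Sketch (assumed): `chunk2` behaves like candle's `Tensor::chunk` followed
// by `to_tuple2`, so enabling `chunk` also enables `to_tuple` transitively.
use candle_ext::{
    candle::{Result, Tensor},
    TensorExt, TensorVecExt,
};

fn split_in_two(t: &Tensor) -> Result<(Tensor, Tensor)> {
    let (_a, _b) = t.chunk2(0usize)?; // tuple-returning convenience
    t.chunk(2, 0usize)?.to_tuple2()   // the same result, spelled out
}
```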
1 change: 1 addition & 0 deletions src/logical_not.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "logical_not")]
use crate::{
candle::{Result, Tensor},
F,
1 change: 1 addition & 0 deletions src/logical_or.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "logical_or")]
use crate::{
candle::{self, CpuStorage, CustomOp2, Layout, Result, Shape, Tensor, WithDType},
F,
1 change: 1 addition & 0 deletions src/masked_fill.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "masked_fill")]
use crate::{
candle::{Result, Tensor, WithDType},
TensorExt, F,
1 change: 1 addition & 0 deletions src/outer.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "outer")]
use crate::{
candle::{Result, Tensor, D},
F,
1 change: 1 addition & 0 deletions src/scaled_dot_product_attention.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "scaled_dot_product_attention")]
use crate::{
candle::{nn::ops, DType, Result, Tensor, D},
TensorExt, F,
1 change: 1 addition & 0 deletions src/to_tuple.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "to_tuple")]
use crate::{
candle::{Error, Result, Tensor},
F,
1 change: 1 addition & 0 deletions src/triangular.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "triangular")]
use crate::{
candle::{Result, Tensor},
F,
1 change: 1 addition & 0 deletions src/unbind.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "unbind")]
use crate::{
candle::{shape::Dim, Error, Result, Tensor},
TensorVecExt, F,
1 change: 1 addition & 0 deletions tests/chunk.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "chunk")]
use candle_ext::{
candle::{Device, Result, Tensor},
TensorExt,
1 change: 1 addition & 0 deletions tests/cumsum.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "cumsum")]
use candle_ext::{
candle::{Device, Result, Tensor},
TensorExt, F,
1 change: 1 addition & 0 deletions tests/equal.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "equal")]
use candle_ext::{
candle::{DType, Device, Result, Tensor},
TensorExt, F,
1 change: 1 addition & 0 deletions tests/eye.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "eye")]
use candle_ext::{
candle::{DType, Device, Result, Tensor},
TensorExt, F,
1 change: 1 addition & 0 deletions tests/full.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "full")]
use candle_ext::{
candle::{Device, Result, Tensor},
F,
13 changes: 11 additions & 2 deletions tests/logical_not.rs
@@ -1,12 +1,21 @@
#![cfg(feature = "logical_not")]
use candle_ext::{
candle::{DType, Device, Result, Tensor},
candle::{Device, Result, Tensor},
TensorExt,
};

#[test]
fn test_logical_not_1() -> Result<()> {
let device = Device::Cpu;
let a = Tensor::ones((2, 2), DType::U8, &device)?.triu(0)?;

#[rustfmt::skip]
let a = Tensor::new(&[
[1u8, 1],
[0, 1]
],&device)?;

println!("{}", a.logical_not()?);

#[rustfmt::skip]
assert_eq!(a.logical_not()?.to_vec2::<u8>()?, &[
[0, 0],
1 change: 1 addition & 0 deletions tests/logical_or.rs
@@ -1,3 +1,4 @@
#![cfg(feature = "logical_or")]
use candle_ext::{
candle::{DType, Device, Result, Tensor},
TensorExt,