Skip to content

Commit

Permalink
fix the tests by making the DType Sync
Browse files Browse the repository at this point in the history
  • Loading branch information
irevoire authored and sharkdp committed Dec 27, 2024
1 parent 13a665a commit 77decee
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 19 deletions.
4 changes: 2 additions & 2 deletions numbat/src/typechecker/constraints.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::rc::Rc;
use std::sync::Arc;

use compact_str::{format_compact, CompactString};

Expand Down Expand Up @@ -308,7 +308,7 @@ impl Constraint {
Constraint::EqualScalar(d) if d == &DType::scalar() => Some(Satisfied::trivially()),
Constraint::EqualScalar(dtype) => match dtype.split_first_factor() {
Some(((DTypeFactor::TVar(tv), k), rest)) => {
let result = DType::from_factors(Rc::new(
let result = DType::from_factors(Arc::new(
rest.iter().map(|(f, j)| (f.clone(), -j / k)).collect(),
));
Some(Satisfied::with_substitution(Substitution::single(
Expand Down
4 changes: 2 additions & 2 deletions numbat/src/typechecker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub mod type_scheme;

use std::collections::HashMap;
use std::ops::Deref;
use std::rc::Rc;
use std::sync::Arc;

use crate::arithmetic::Exponent;
use crate::ast::{
Expand Down Expand Up @@ -252,7 +252,7 @@ impl TypeChecker {
.into_factors();

// Replace BaseDimension("D") with TVar("D") for all type parameters
for (factor, _) in Rc::make_mut(&mut factors) {
for (factor, _) in Arc::make_mut(&mut factors) {
*factor = match factor {
DTypeFactor::BaseDimension(ref n)
if self
Expand Down
30 changes: 15 additions & 15 deletions numbat/src/typed_ast.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::rc::Rc;
use std::sync::Arc;

use compact_str::{format_compact, CompactString, ToCompactString};
use indexmap::IndexMap;
Expand Down Expand Up @@ -44,26 +44,26 @@ type DtypeFactorPower = (DTypeFactor, Exponent);
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DType {
// Always in canonical form
factors: Rc<Vec<DtypeFactorPower>>,
factors: Arc<Vec<DtypeFactorPower>>,
}

impl DType {
pub fn factors(&self) -> &[DtypeFactorPower] {
&self.factors
}

pub fn into_factors(self) -> Rc<Vec<DtypeFactorPower>> {
pub fn into_factors(self) -> Arc<Vec<DtypeFactorPower>> {
self.factors
}

pub fn from_factors(factors: Rc<Vec<DtypeFactorPower>>) -> DType {
pub fn from_factors(factors: Arc<Vec<DtypeFactorPower>>) -> DType {
let mut dtype = DType { factors };
dtype.canonicalize();
dtype
}

pub fn scalar() -> DType {
DType::from_factors(Rc::new(vec![]))
DType::from_factors(Arc::new(vec![]))
}

pub fn is_scalar(&self) -> bool {
Expand Down Expand Up @@ -102,14 +102,14 @@ impl DType {
}

pub fn from_type_variable(v: TypeVariable) -> DType {
DType::from_factors(Rc::new(vec![(
DType::from_factors(Arc::new(vec![(
DTypeFactor::TVar(v),
Exponent::from_integer(1),
)]))
}

pub fn from_type_parameter(name: CompactString) -> DType {
DType::from_factors(Rc::new(vec![(
DType::from_factors(Arc::new(vec![(
DTypeFactor::TPar(name),
Exponent::from_integer(1),
)]))
Expand All @@ -125,22 +125,22 @@ impl DType {
}

pub fn from_tgen(i: usize) -> DType {
DType::from_factors(Rc::new(vec![(
DType::from_factors(Arc::new(vec![(
DTypeFactor::TVar(TypeVariable::Quantified(i)),
Exponent::from_integer(1),
)]))
}

pub fn base_dimension(name: &str) -> DType {
DType::from_factors(Rc::new(vec![(
DType::from_factors(Arc::new(vec![(
DTypeFactor::BaseDimension(name.into()),
Exponent::from_integer(1),
)]))
}

fn canonicalize(&mut self) {
// Move all type-variable and tgen factors to the front, sort by name
Rc::make_mut(&mut self.factors).sort_by(|(f1, _), (f2, _)| match (f1, f2) {
Arc::make_mut(&mut self.factors).sort_by(|(f1, _), (f2, _)| match (f1, f2) {
(DTypeFactor::TVar(v1), DTypeFactor::TVar(v2)) => v1.cmp(v2),
(DTypeFactor::TVar(_), _) => std::cmp::Ordering::Less,

Expand All @@ -167,12 +167,12 @@ impl DType {
// Remove factors with zero exponent:
new_factors.retain(|(_, n)| *n != Exponent::from_integer(0));

self.factors = Rc::new(new_factors);
self.factors = Arc::new(new_factors);
}

pub fn multiply(&self, other: &DType) -> DType {
let mut factors = self.factors.clone();
Rc::make_mut(&mut factors).extend(other.factors.iter().cloned());
Arc::make_mut(&mut factors).extend(other.factors.iter().cloned());
DType::from_factors(factors)
}

Expand All @@ -182,7 +182,7 @@ impl DType {
.iter()
.map(|(f, m)| (f.clone(), n * m))
.collect();
DType::from_factors(Rc::new(factors))
DType::from_factors(Arc::new(factors))
}

pub fn inverse(&self) -> DType {
Expand Down Expand Up @@ -236,7 +236,7 @@ impl DType {
}
}
}
Self::from_factors(Rc::new(factors))
Self::from_factors(Arc::new(factors))
}

pub fn to_base_representation(&self) -> BaseRepresentation {
Expand Down Expand Up @@ -279,7 +279,7 @@ impl From<BaseRepresentation> for DType {
.into_iter()
.map(|BaseRepresentationFactor(name, exp)| (DTypeFactor::BaseDimension(name), exp))
.collect();
DType::from_factors(Rc::new(factors))
DType::from_factors(Arc::new(factors))
}
}

Expand Down

0 comments on commit 77decee

Please sign in to comment.