diff --git a/numbat/src/typechecker/constraints.rs b/numbat/src/typechecker/constraints.rs index 34e6cbeb..e48b6bf2 100644 --- a/numbat/src/typechecker/constraints.rs +++ b/numbat/src/typechecker/constraints.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::sync::Arc; use compact_str::{format_compact, CompactString}; @@ -308,7 +308,7 @@ impl Constraint { Constraint::EqualScalar(d) if d == &DType::scalar() => Some(Satisfied::trivially()), Constraint::EqualScalar(dtype) => match dtype.split_first_factor() { Some(((DTypeFactor::TVar(tv), k), rest)) => { - let result = DType::from_factors(Rc::new( + let result = DType::from_factors(Arc::new( rest.iter().map(|(f, j)| (f.clone(), -j / k)).collect(), )); Some(Satisfied::with_substitution(Substitution::single( diff --git a/numbat/src/typechecker/mod.rs b/numbat/src/typechecker/mod.rs index 4f553cf2..a34a4308 100644 --- a/numbat/src/typechecker/mod.rs +++ b/numbat/src/typechecker/mod.rs @@ -14,7 +14,7 @@ pub mod type_scheme; use std::collections::HashMap; use std::ops::Deref; -use std::rc::Rc; +use std::sync::Arc; use crate::arithmetic::Exponent; use crate::ast::{ @@ -252,7 +252,7 @@ impl TypeChecker { .into_factors(); // Replace BaseDimension("D") with TVar("D") for all type parameters - for (factor, _) in Rc::make_mut(&mut factors) { + for (factor, _) in Arc::make_mut(&mut factors) { *factor = match factor { DTypeFactor::BaseDimension(ref n) if self diff --git a/numbat/src/typed_ast.rs b/numbat/src/typed_ast.rs index cb8770a5..a242b455 100644 --- a/numbat/src/typed_ast.rs +++ b/numbat/src/typed_ast.rs @@ -1,4 +1,4 @@ -use std::rc::Rc; +use std::sync::Arc; use compact_str::{format_compact, CompactString, ToCompactString}; use indexmap::IndexMap; @@ -44,7 +44,7 @@ type DtypeFactorPower = (DTypeFactor, Exponent); #[derive(Clone, Debug, PartialEq, Eq)] pub struct DType { // Always in canonical form - factors: Rc>, + factors: Arc>, } impl DType { @@ -52,18 +52,18 @@ impl DType { &self.factors } - pub fn into_factors(self) -> Rc> { + pub fn into_factors(self) -> Arc> { self.factors } - pub fn from_factors(factors: Rc>) -> DType { + pub fn from_factors(factors: Arc>) -> DType { let mut dtype = DType { factors }; dtype.canonicalize(); dtype } pub fn scalar() -> DType { - DType::from_factors(Rc::new(vec![])) + DType::from_factors(Arc::new(vec![])) } pub fn is_scalar(&self) -> bool { @@ -102,14 +102,14 @@ impl DType { } pub fn from_type_variable(v: TypeVariable) -> DType { - DType::from_factors(Rc::new(vec![( + DType::from_factors(Arc::new(vec![( DTypeFactor::TVar(v), Exponent::from_integer(1), )])) } pub fn from_type_parameter(name: CompactString) -> DType { - DType::from_factors(Rc::new(vec![( + DType::from_factors(Arc::new(vec![( DTypeFactor::TPar(name), Exponent::from_integer(1), )])) @@ -125,14 +125,14 @@ impl DType { } pub fn from_tgen(i: usize) -> DType { - DType::from_factors(Rc::new(vec![( + DType::from_factors(Arc::new(vec![( DTypeFactor::TVar(TypeVariable::Quantified(i)), Exponent::from_integer(1), )])) } pub fn base_dimension(name: &str) -> DType { - DType::from_factors(Rc::new(vec![( + DType::from_factors(Arc::new(vec![( DTypeFactor::BaseDimension(name.into()), Exponent::from_integer(1), )])) @@ -140,7 +140,7 @@ impl DType { fn canonicalize(&mut self) { // Move all type-variable and tgen factors to the front, sort by name - Rc::make_mut(&mut self.factors).sort_by(|(f1, _), (f2, _)| match (f1, f2) { + Arc::make_mut(&mut self.factors).sort_by(|(f1, _), (f2, _)| match (f1, f2) { (DTypeFactor::TVar(v1), DTypeFactor::TVar(v2)) => v1.cmp(v2), (DTypeFactor::TVar(_), _) => std::cmp::Ordering::Less, @@ -167,12 +167,12 @@ impl DType { // Remove factors with zero exponent: new_factors.retain(|(_, n)| *n != Exponent::from_integer(0)); - self.factors = Rc::new(new_factors); + self.factors = Arc::new(new_factors); } pub fn multiply(&self, other: &DType) -> DType { let mut factors = self.factors.clone(); - Rc::make_mut(&mut factors).extend(other.factors.iter().cloned()); + Arc::make_mut(&mut factors).extend(other.factors.iter().cloned()); DType::from_factors(factors) } @@ -182,7 +182,7 @@ impl DType { .iter() .map(|(f, m)| (f.clone(), n * m)) .collect(); - DType::from_factors(Rc::new(factors)) + DType::from_factors(Arc::new(factors)) } pub fn inverse(&self) -> DType { @@ -236,7 +236,7 @@ impl DType { } } } - Self::from_factors(Rc::new(factors)) + Self::from_factors(Arc::new(factors)) } pub fn to_base_representation(&self) -> BaseRepresentation { @@ -279,7 +279,7 @@ impl From for DType { .into_iter() .map(|BaseRepresentationFactor(name, exp)| (DTypeFactor::BaseDimension(name), exp)) .collect(); - DType::from_factors(Rc::new(factors)) + DType::from_factors(Arc::new(factors)) } }