Skip to content

Commit

Permalink
stop cloning the whole list of DType factor when we don't need to
Browse files Browse the repository at this point in the history
  • Loading branch information
irevoire authored and sharkdp committed Dec 27, 2024
1 parent 740c4f3 commit 13a665a
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 27 deletions.
6 changes: 4 additions & 2 deletions numbat/src/typechecker/constraints.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::rc::Rc;

use compact_str::{format_compact, CompactString};

use super::substitutions::{ApplySubstitution, Substitution, SubstitutionError};
Expand Down Expand Up @@ -306,9 +308,9 @@ impl Constraint {
Constraint::EqualScalar(d) if d == &DType::scalar() => Some(Satisfied::trivially()),
Constraint::EqualScalar(dtype) => match dtype.split_first_factor() {
Some(((DTypeFactor::TVar(tv), k), rest)) => {
let result = DType::from_factors(
let result = DType::from_factors(Rc::new(
rest.iter().map(|(f, j)| (f.clone(), -j / k)).collect(),
);
));
Some(Satisfied::with_substitution(Substitution::single(
tv.clone(),
Type::Dimension(result),
Expand Down
3 changes: 2 additions & 1 deletion numbat/src/typechecker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod type_scheme;

use std::collections::HashMap;
use std::ops::Deref;
use std::rc::Rc;

use crate::arithmetic::Exponent;
use crate::ast::{
Expand Down Expand Up @@ -251,7 +252,7 @@ impl TypeChecker {
.into_factors();

// Replace BaseDimension("D") with TVar("D") for all type parameters
for (factor, _) in &mut factors {
for (factor, _) in Rc::make_mut(&mut factors) {
*factor = match factor {
DTypeFactor::BaseDimension(ref n)
if self
Expand Down
56 changes: 32 additions & 24 deletions numbat/src/typed_ast.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::rc::Rc;

use compact_str::{format_compact, CompactString, ToCompactString};
use indexmap::IndexMap;
use itertools::Itertools;
Expand Down Expand Up @@ -42,26 +44,26 @@ type DtypeFactorPower = (DTypeFactor, Exponent);
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DType {
// Always in canonical form
factors: Vec<DtypeFactorPower>,
factors: Rc<Vec<DtypeFactorPower>>,
}

impl DType {
pub fn factors(&self) -> &[DtypeFactorPower] {
&self.factors
}

pub fn into_factors(self) -> Vec<DtypeFactorPower> {
pub fn into_factors(self) -> Rc<Vec<DtypeFactorPower>> {
self.factors
}

pub fn from_factors(factors: Vec<DtypeFactorPower>) -> DType {
pub fn from_factors(factors: Rc<Vec<DtypeFactorPower>>) -> DType {
let mut dtype = DType { factors };
dtype.canonicalize();
dtype
}

pub fn scalar() -> DType {
DType::from_factors(vec![])
DType::from_factors(Rc::new(vec![]))
}

pub fn is_scalar(&self) -> bool {
Expand Down Expand Up @@ -100,11 +102,17 @@ impl DType {
}

pub fn from_type_variable(v: TypeVariable) -> DType {
DType::from_factors(vec![(DTypeFactor::TVar(v), Exponent::from_integer(1))])
DType::from_factors(Rc::new(vec![(
DTypeFactor::TVar(v),
Exponent::from_integer(1),
)]))
}

pub fn from_type_parameter(name: CompactString) -> DType {
DType::from_factors(vec![(DTypeFactor::TPar(name), Exponent::from_integer(1))])
DType::from_factors(Rc::new(vec![(
DTypeFactor::TPar(name),
Exponent::from_integer(1),
)]))
}

pub fn deconstruct_as_single_type_variable(&self) -> Option<TypeVariable> {
Expand All @@ -117,22 +125,22 @@ impl DType {
}

pub fn from_tgen(i: usize) -> DType {
DType::from_factors(vec![(
DType::from_factors(Rc::new(vec![(
DTypeFactor::TVar(TypeVariable::Quantified(i)),
Exponent::from_integer(1),
)])
)]))
}

pub fn base_dimension(name: &str) -> DType {
DType::from_factors(vec![(
DType::from_factors(Rc::new(vec![(
DTypeFactor::BaseDimension(name.into()),
Exponent::from_integer(1),
)])
)]))
}

fn canonicalize(&mut self) {
// Move all type-variable and tgen factors to the front, sort by name
self.factors.sort_by(|(f1, _), (f2, _)| match (f1, f2) {
Rc::make_mut(&mut self.factors).sort_by(|(f1, _), (f2, _)| match (f1, f2) {
(DTypeFactor::TVar(v1), DTypeFactor::TVar(v2)) => v1.cmp(v2),
(DTypeFactor::TVar(_), _) => std::cmp::Ordering::Less,

Expand All @@ -146,7 +154,7 @@ impl DType {

// Merge powers of equal factors:
let mut new_factors = Vec::new();
for (f, n) in &self.factors {
for (f, n) in self.factors.iter() {
if let Some((last_f, last_n)) = new_factors.last_mut() {
if f == last_f {
*last_n += n;
Expand All @@ -155,16 +163,16 @@ impl DType {
}
new_factors.push((f.clone(), *n));
}
self.factors = new_factors;

// Remove factors with zero exponent:
self.factors
.retain(|(_, n)| *n != Exponent::from_integer(0));
new_factors.retain(|(_, n)| *n != Exponent::from_integer(0));

self.factors = Rc::new(new_factors);
}

pub fn multiply(&self, other: &DType) -> DType {
let mut factors = self.factors.clone();
factors.extend(other.factors.clone());
Rc::make_mut(&mut factors).extend(other.factors.iter().cloned());
DType::from_factors(factors)
}

Expand All @@ -174,7 +182,7 @@ impl DType {
.iter()
.map(|(f, m)| (f.clone(), n * m))
.collect();
DType::from_factors(factors)
DType::from_factors(Rc::new(factors))
}

pub fn inverse(&self) -> DType {
Expand Down Expand Up @@ -218,7 +226,7 @@ impl DType {
fn instantiate(&self, type_variables: &[TypeVariable]) -> DType {
let mut factors = Vec::new();

for (f, n) in &self.factors {
for (f, n) in self.factors.iter() {
match f {
DTypeFactor::TVar(TypeVariable::Quantified(i)) => {
factors.push((DTypeFactor::TVar(type_variables[*i].clone()), *n));
Expand All @@ -228,12 +236,12 @@ impl DType {
}
}
}
Self::from_factors(factors)
Self::from_factors(Rc::new(factors))
}

pub fn to_base_representation(&self) -> BaseRepresentation {
let mut factors = vec![];
for (f, n) in &self.factors {
for (f, n) in self.factors.iter() {
match f {
DTypeFactor::BaseDimension(name) => {
factors.push(BaseRepresentationFactor(name.clone(), *n));
Expand Down Expand Up @@ -271,7 +279,7 @@ impl From<BaseRepresentation> for DType {
.into_iter()
.map(|BaseRepresentationFactor(name, exp)| (DTypeFactor::BaseDimension(name), exp))
.collect();
DType::from_factors(factors)
DType::from_factors(Rc::new(factors))
}
}

Expand Down Expand Up @@ -729,9 +737,9 @@ impl Statement<'_> {
let mut exponents = vec![];
self.for_all_type_schemes(&mut |type_: &mut TypeScheme| {
if let Type::Dimension(dtype) = type_.unsafe_as_concrete() {
for (factor, exp) in dtype.factors {
if factor == DTypeFactor::TVar(tv.clone()) {
exponents.push(exp)
for (factor, exp) in dtype.factors.iter() {
if factor == &DTypeFactor::TVar(tv.clone()) {
exponents.push(*exp)
}
}
}
Expand Down

0 comments on commit 13a665a

Please sign in to comment.