Skip to content

Commit

Permalink
Refactor Detectors: remove unnecessary AST visitors (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexroan authored Nov 7, 2023
1 parent b457820 commit 8f4ecae
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 140 deletions.
59 changes: 25 additions & 34 deletions src/detect/low/avoid_abi_encode_packed.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::error::Error;

use crate::{
ast::MemberAccess,
context::loader::{ASTNode, ContextLoader},
detect::detector::{Detector, IssueSeverity},
visitor::ast_visitor::{ASTConstVisitor, Node},
};
use eyre::Result;

Expand All @@ -13,43 +11,36 @@ pub struct AvoidAbiEncodePackedDetector {
found_abi_encode_packed: Vec<Option<ASTNode>>,
}

impl ASTConstVisitor for AvoidAbiEncodePackedDetector {
fn visit_member_access(&mut self, node: &MemberAccess) -> Result<bool> {
// If the node's member_name = "encodePacked", loop through the argument_types and count how many of them contain any of the following in type_strings:
// ["bytes ", "[]", "string"]
// If the count is greater than 1, add the node to the found_abi_encode_packed vector
if node.member_name == "encodePacked" {
let mut count = 0;
let argument_types = node.argument_types.as_ref().unwrap();
for argument_type in argument_types {
if argument_type
.type_string
.as_ref()
.unwrap()
.contains("bytes ")
|| argument_type.type_string.as_ref().unwrap().contains("[]")
|| argument_type
impl Detector for AvoidAbiEncodePackedDetector {
fn detect(&mut self, loader: &ContextLoader) -> Result<bool, Box<dyn Error>> {
for member_access in loader.get_member_accesses() {
// If the member_access's member_name = "encodePacked", loop through the argument_types and count how many of them contain any of the following in type_strings:
// ["bytes ", "[]", "string"]
// If the count is greater than 1, add the member_access to the found_abi_encode_packed vector
if member_access.member_name == "encodePacked" {
let mut count = 0;
let argument_types = member_access.argument_types.as_ref().unwrap();
for argument_type in argument_types {
if argument_type
.type_string
.as_ref()
.unwrap()
.contains("string")
{
count += 1;
.contains("bytes ")
|| argument_type.type_string.as_ref().unwrap().contains("[]")
|| argument_type
.type_string
.as_ref()
.unwrap()
.contains("string")
{
count += 1;
}
}
if count > 1 {
self.found_abi_encode_packed
.push(Some(ASTNode::MemberAccess(member_access.clone())));
}
}
if count > 1 {
self.found_abi_encode_packed
.push(Some(ASTNode::MemberAccess(node.clone())));
}
}
Ok(true)
}
}

impl Detector for AvoidAbiEncodePackedDetector {
fn detect(&mut self, loader: &ContextLoader) -> Result<bool, Box<dyn Error>> {
for member_access in loader.get_member_accesses() {
member_access.accept(self)?;
}
Ok(!self.found_abi_encode_packed.is_empty())
}
Expand Down
30 changes: 6 additions & 24 deletions src/detect/low/deprecated_oz_functions.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::error::Error;

use crate::{
ast::{Identifier, MemberAccess},
context::loader::{ASTNode, ContextLoader},
detect::detector::{Detector, IssueSeverity},
visitor::ast_visitor::{ASTConstVisitor, Node},
};
use eyre::Result;

Expand All @@ -13,24 +11,6 @@ pub struct DeprecatedOZFunctionsDetector {
found_deprecated_oz_functions: Vec<Option<ASTNode>>,
}

impl ASTConstVisitor for DeprecatedOZFunctionsDetector {
fn visit_identifier(&mut self, node: &Identifier) -> Result<bool> {
if node.name == "_setupRole" {
self.found_deprecated_oz_functions
.push(Some(ASTNode::Identifier(node.clone())));
}
Ok(true)
}

fn visit_member_access(&mut self, node: &MemberAccess) -> Result<bool> {
if node.member_name == "safeApprove" {
self.found_deprecated_oz_functions
.push(Some(ASTNode::MemberAccess(node.clone())));
}
Ok(true)
}
}

impl Detector for DeprecatedOZFunctionsDetector {
fn detect(&mut self, loader: &ContextLoader) -> Result<bool, Box<dyn Error>> {
for identifier in loader.get_identifiers() {
Expand All @@ -46,8 +26,9 @@ impl Detector for DeprecatedOZFunctionsDetector {
.absolute_path
.as_ref()
.map_or(false, |path| path.contains("openzeppelin"))
}) {
identifier.accept(self)?;
}) && identifier.name == "_setupRole" {
self.found_deprecated_oz_functions
.push(Some(ASTNode::Identifier(identifier.clone())));
}
}
for member_access in loader.get_member_accesses() {
Expand All @@ -62,8 +43,9 @@ impl Detector for DeprecatedOZFunctionsDetector {
.absolute_path
.as_ref()
.map_or(false, |path| path.contains("openzeppelin"))
}) {
member_access.accept(self)?;
}) && member_access.member_name == "safeApprove" {
self.found_deprecated_oz_functions
.push(Some(ASTNode::MemberAccess(member_access.clone())));
}
}
Ok(!self.found_deprecated_oz_functions.is_empty())
Expand Down
17 changes: 4 additions & 13 deletions src/detect/low/ecrecover.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::error::Error;

use crate::{
ast::Identifier,
context::loader::{ASTNode, ContextLoader},
detect::detector::{Detector, IssueSeverity},
visitor::ast_visitor::{ASTConstVisitor, Node},
};
use eyre::Result;

Expand All @@ -13,20 +11,13 @@ pub struct EcrecoverDetector {
found_ecrecover: Vec<Option<ASTNode>>,
}

impl ASTConstVisitor for EcrecoverDetector {
fn visit_identifier(&mut self, node: &Identifier) -> Result<bool> {
if node.name == "ecrecover" {
self.found_ecrecover
.push(Some(ASTNode::Identifier(node.clone())));
}
Ok(true)
}
}

impl Detector for EcrecoverDetector {
fn detect(&mut self, loader: &ContextLoader) -> Result<bool, Box<dyn Error>> {
for identifier in loader.get_identifiers() {
identifier.accept(self)?;
if identifier.name == "ecrecover" {
self.found_ecrecover
.push(Some(ASTNode::Identifier(identifier.clone())));
}
}
Ok(!self.found_ecrecover.is_empty())
}
Expand Down
23 changes: 7 additions & 16 deletions src/detect/low/unsafe_erc20_functions.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::error::Error;

use crate::{
ast::MemberAccess,
context::loader::{ASTNode, ContextLoader},
detect::detector::{Detector, IssueSeverity},
visitor::ast_visitor::{ASTConstVisitor, Node},
};
use eyre::Result;

Expand All @@ -13,23 +11,16 @@ pub struct UnsafeERC20FunctionsDetector {
found_unsafe_erc20_functions: Vec<Option<ASTNode>>,
}

impl ASTConstVisitor for UnsafeERC20FunctionsDetector {
fn visit_member_access(&mut self, node: &MemberAccess) -> Result<bool> {
if node.member_name == "transferFrom"
|| node.member_name == "approve"
|| node.member_name == "transfer"
{
self.found_unsafe_erc20_functions
.push(Some(ASTNode::MemberAccess(node.clone())));
}
Ok(true)
}
}

impl Detector for UnsafeERC20FunctionsDetector {
fn detect(&mut self, loader: &ContextLoader) -> Result<bool, Box<dyn Error>> {
for member_access in loader.get_member_accesses() {
member_access.accept(self)?;
if member_access.member_name == "transferFrom"
|| member_access.member_name == "approve"
|| member_access.member_name == "transfer"
{
self.found_unsafe_erc20_functions
.push(Some(ASTNode::MemberAccess(member_access.clone())));
}
}
Ok(!self.found_unsafe_erc20_functions.is_empty())
}
Expand Down
23 changes: 7 additions & 16 deletions src/detect/low/unspecific_solidity_pragma.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
use std::error::Error;

use crate::{
ast::PragmaDirective,
context::loader::{ASTNode, ContextLoader},
detect::detector::{Detector, IssueSeverity},
visitor::ast_visitor::{ASTConstVisitor, Node},
};
use eyre::Result;

Expand All @@ -13,23 +11,16 @@ pub struct UnspecificSolidityPragmaDetector {
found_unspecific_solidity_pragma: Vec<Option<ASTNode>>,
}

impl ASTConstVisitor for UnspecificSolidityPragmaDetector {
fn visit_pragma_directive(&mut self, node: &PragmaDirective) -> Result<bool> {
for literal in &node.literals {
if literal.contains('^') || literal.contains('>') {
self.found_unspecific_solidity_pragma
.push(Some(ASTNode::PragmaDirective(node.clone())));
break;
}
}
Ok(true)
}
}

impl Detector for UnspecificSolidityPragmaDetector {
fn detect(&mut self, loader: &ContextLoader) -> Result<bool, Box<dyn Error>> {
for pragma_directive in loader.get_pragma_directives() {
pragma_directive.accept(self)?;
for literal in &pragma_directive.literals {
if literal.contains('^') || literal.contains('>') {
self.found_unspecific_solidity_pragma
.push(Some(ASTNode::PragmaDirective(pragma_directive.clone())));
break;
}
}
}
Ok(!self.found_unspecific_solidity_pragma.is_empty())
}
Expand Down
61 changes: 24 additions & 37 deletions src/detect/medium/solmate_safe_transfer_lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
use std::error::Error;

use crate::ast::ImportDirective;
use crate::visitor::ast_visitor::Node;
use crate::{
ast::MemberAccess,
context::loader::{ASTNode, ContextLoader},
detect::detector::{Detector, IssueSeverity},
visitor::ast_visitor::ASTConstVisitor,
};
use eyre::Result;

Expand All @@ -16,46 +12,37 @@ pub struct SolmateSafeTransferLibDetector {
found_transfer_usage: Vec<Option<ASTNode>>,
}

impl ASTConstVisitor for SolmateSafeTransferLibDetector {
fn visit_import_directive(&mut self, node: &ImportDirective) -> Result<bool> {
if !self.found_solmate_import {
// If the import directive absolute_path contains the strings "solmate" and "SafeTransferLib", flip the found_solmate_import flag to true
if node.absolute_path.as_ref().unwrap().contains("solmate")
&& node
impl Detector for SolmateSafeTransferLibDetector {
fn detect(&mut self, loader: &ContextLoader) -> Result<bool, Box<dyn Error>> {
for import_directive in loader.get_import_directives() {
if !self.found_solmate_import {
// If the import directive absolute_path contains the strings "solmate" and "SafeTransferLib", flip the found_solmate_import flag to true
if import_directive
.absolute_path
.as_ref()
.unwrap()
.contains("SafeTransferLib")
{
self.found_solmate_import = true;
.contains("solmate")
&& import_directive
.absolute_path
.as_ref()
.unwrap()
.contains("SafeTransferLib")
{
self.found_solmate_import = true;
}
}
}
Ok(true)
}

fn visit_member_access(&mut self, node: &MemberAccess) -> Result<bool> {
// If the member access member_name is any of the following names, add it to the list of found
// found_transfer_usage vector: ["safeTransfer", "safeTransferFrom", "safeApprove"]
if node.member_name == "safeTransfer"
|| node.member_name == "safeTransferFrom"
|| node.member_name == "safeApprove"
{
self.found_transfer_usage
.push(Some(ASTNode::MemberAccess(node.clone())));
}

Ok(true)
}
}

impl Detector for SolmateSafeTransferLibDetector {
fn detect(&mut self, loader: &ContextLoader) -> Result<bool, Box<dyn Error>> {
for import_directive in loader.get_import_directives() {
import_directive.accept(self)?;
}

for member_access in loader.get_member_accesses() {
member_access.accept(self)?;
// If the member access member_name is any of the following names, add it to the list of found
// found_transfer_usage vector: ["safeTransfer", "safeTransferFrom", "safeApprove"]
if member_access.member_name == "safeTransfer"
|| member_access.member_name == "safeTransferFrom"
|| member_access.member_name == "safeApprove"
{
self.found_transfer_usage
.push(Some(ASTNode::MemberAccess(member_access.clone())));
}
}

if self.found_solmate_import && !self.found_transfer_usage.is_empty() {
Expand Down

0 comments on commit 8f4ecae

Please sign in to comment.