diff --git a/lib/syntax_tree/mutation_visitor.rb b/lib/syntax_tree/mutation_visitor.rb index 0b4b9357..61e33186 100644 --- a/lib/syntax_tree/mutation_visitor.rb +++ b/lib/syntax_tree/mutation_visitor.rb @@ -67,12 +67,12 @@ def visit_alias(node) # Visit a ARef node. def visit_aref(node) - node.copy(index: visit(node.index)) + node.copy(collection: visit(node.collection), index: visit(node.index)) end # Visit a ARefField node. def visit_aref_field(node) - node.copy(index: visit(node.index)) + node.copy(collection: visit(node.collection), index: visit(node.index)) end # Visit a ArgParen node. @@ -120,17 +120,17 @@ def visit_aryptn(node) # Visit a Assign node. def visit_assign(node) - node.copy(target: visit(node.target)) + node.copy(target: visit(node.target), value: visit(node.value)) end # Visit a Assoc node. def visit_assoc(node) - node.copy + node.copy(key: visit(node.key), value: visit(node.value)) end # Visit a AssocSplat node. def visit_assoc_splat(node) - node.copy + node.copy(value: visit(node.value)) end # Visit a Backref node. @@ -155,12 +155,12 @@ def visit_begin(node) # Visit a PinnedBegin node. def visit_pinned_begin(node) - node.copy + node.copy(statement: visit(node.statement)) end # Visit a Binary node. def visit_binary(node) - node.copy + node.copy(left: visit(node.left), right: visit(node.right)) end # Visit a BlockVar node. @@ -178,6 +178,7 @@ def visit_bodystmt(node) node.copy( statements: visit(node.statements), rescue_clause: visit(node.rescue_clause), + else_keyword: visit(node.else_keyword), else_clause: visit(node.else_clause), ensure_clause: visit(node.ensure_clause) ) @@ -209,7 +210,11 @@ def visit_case(node) # Visit a RAssign node. def visit_rassign(node) - node.copy(operator: visit(node.operator)) + node.copy( + value: visit(node.value), + operator: visit(node.operator), + pattern: visit(node.pattern) + ) end # Visit a ClassDeclaration node. @@ -238,6 +243,7 @@ def visit_command(node) # Visit a CommandCall node. def visit_command_call(node) node.copy( + receiver: visit(node.receiver), operator: node.operator == :"::" ? :"::" : visit(node.operator), message: visit(node.message), arguments: visit(node.arguments), @@ -257,12 +263,12 @@ def visit_const(node) # Visit a ConstPathField node. def visit_const_path_field(node) - node.copy(constant: visit(node.constant)) + node.copy(parent: visit(node.parent), constant: visit(node.constant)) end # Visit a ConstPathRef node. def visit_const_path_ref(node) - node.copy(constant: visit(node.constant)) + node.copy(parent: visit(node.parent), constant: visit(node.constant)) end # Visit a ConstRef node. @@ -288,7 +294,7 @@ def visit_def(node) # Visit a Defined node. def visit_defined(node) - node.copy + node.copy(value: visit(node.value)) end # Visit a Block node. @@ -325,6 +331,7 @@ def visit_else(node) # Visit a Elsif node. def visit_elsif(node) node.copy( + predicate: visit(node.predicate), statements: visit(node.statements), consequent: visit(node.consequent) ) @@ -366,6 +373,7 @@ def visit_excessed_comma(node) # Visit a Field node. def visit_field(node) node.copy( + parent: visit(node.parent), operator: node.operator == :"::" ? :"::" : visit(node.operator), name: visit(node.name) ) @@ -388,7 +396,11 @@ def visit_fndptn(node) # Visit a For node. def visit_for(node) - node.copy(index: visit(node.index), statements: visit(node.statements)) + node.copy( + index: visit(node.index), + collection: visit(node.collection), + statements: visit(node.statements) + ) end # Visit a GVar node. @@ -446,7 +458,11 @@ def visit_if(node) # Visit a IfOp node. def visit_if_op(node) - node.copy + node.copy( + predicate: visit(node.predicate), + truthy: visit(node.truthy), + falsy: visit(node.falsy) + ) end # Visit a Imaginary node. @@ -457,6 +473,7 @@ def visit_imaginary(node) # Visit a In node. def visit_in(node) node.copy( + pattern: visit(node.pattern), statements: visit(node.statements), consequent: visit(node.consequent) ) @@ -522,7 +539,7 @@ def visit_lparen(node) # Visit a MAssign node. def visit_massign(node) - node.copy(target: visit(node.target)) + node.copy(target: visit(node.target), value: visit(node.value)) end # Visit a MethodAddBlock node. @@ -565,7 +582,11 @@ def visit_op(node) # Visit a OpAssign node. def visit_opassign(node) - node.copy(target: visit(node.target), operator: visit(node.operator)) + node.copy( + target: visit(node.target), + operator: visit(node.operator), + value: visit(node.value) + ) end # Visit a Params node. @@ -667,7 +688,10 @@ def visit_regexp_literal(node) # Visit a RescueEx node. def visit_rescue_ex(node) - node.copy(variable: visit(node.variable)) + node.copy( + exceptions: visit(node.exceptions), + variable: visit(node.variable) + ) end # Visit a Rescue node. @@ -682,7 +706,7 @@ def visit_rescue(node) # Visit a RescueMod node. def visit_rescue_mod(node) - node.copy + node.copy(statement: visit(node.statement), value: visit(node.value)) end # Visit a RestParam node. @@ -707,7 +731,7 @@ def visit_rparen(node) # Visit a SClass node. def visit_sclass(node) - node.copy(bodystmt: visit(node.bodystmt)) + node.copy(target: visit(node.target), bodystmt: visit(node.bodystmt)) end # Visit a Statements node. @@ -815,7 +839,7 @@ def visit_not(node) # Visit a Unary node. def visit_unary(node) - node.copy + node.copy(statement: visit(node.statement)) end # Visit a Undef node. @@ -842,7 +866,7 @@ def visit_until(node) # Visit a VarField node. def visit_var_field(node) - node.copy(value: visit(node.value)) + node.copy(value: node.value == :nil ? :nil : visit(node.value)) end # Visit a VarRef node. diff --git a/test/mutation_test.rb b/test/mutation_test.rb index ab9dd019..f97b3ddc 100644 --- a/test/mutation_test.rb +++ b/test/mutation_test.rb @@ -21,6 +21,26 @@ def test_mutates_based_on_patterns assert_equal(expected, SyntaxTree::Formatter.format(source, program)) end + def test_deep_mutation + source = <<~RUBY + hash = { "key" => a ? foo : nil } + RUBY + + expected = <<~RUBY + hash = { "key" => a ? bar : nil } + RUBY + + rename_foo_into_bar = + SyntaxTree.mutation do |mutation| + mutation.mutate("Ident[value: 'foo']") do |node| + node.copy(value: "bar") + end + end + + program = SyntaxTree.parse(source).accept(rename_foo_into_bar) + assert_equal(expected, SyntaxTree::Formatter.format(source, program)) + end + private def build_mutation