Skip to content

Commit

Permalink
Add tests for nbt tags
Browse files Browse the repository at this point in the history
  • Loading branch information
DonoA committed Jan 7, 2019
1 parent e026905 commit 9a58b04
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 41 deletions.
71 changes: 55 additions & 16 deletions pyanvil/nbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,31 @@ def register_parser(id, clazz):

_parsers[id] = clazz

def create_simple_nbt_class(tag_id, tag_name, tag_width, tag_parser):
def create_simple_nbt_class(tag_id, class_tag_name, tag_width, tag_parser):

class DataNBTTag:

clazz_width = tag_width
clazz_name = tag_name
clazz_name = class_tag_name
clazz_parser = tag_parser
clazz_id = tag_id

@classmethod
def parse(cls, stream, name):
return cls(
name,
struct.unpack(
tag_value=struct.unpack(
cls.clazz_parser,
stream.read(cls.clazz_width)
)[0]
)[0],
tag_name=name
)

def __init__(self, tag_name, tag_value):
def __init__(self, tag_value, tag_name='None'):
self.tag_name = tag_name
self.tag_value = tag_value

def print(self, indent=''):
print(indent + type(self).clazz_name + ': ' + self.tag_name + ' = ' + str(self.tag_value))
print(indent + self.__repr__())

def get(self):
return self.tag_value
Expand All @@ -53,7 +53,7 @@ def clone(self):
return type(self)(self.tag_name, self.tag_value)

def __repr__(self):
return f'{type(self).clazz_name}Tag(\'{self.tag_name}\', {str(self.tag_value)})'
return f'{type(self).clazz_name}Tag \'{self.tag_name}\' = {str(self.tag_value)}'

def __eq__(self, other):
return self.tag_name == other.tag_name and self.tag_value == other.tag_value
Expand All @@ -71,9 +71,9 @@ class DataNBTTag:
def parse(cls, stream, name):
payload_length = int.from_bytes(stream.read(2), byteorder='big', signed=False)
payload = stream.read(payload_length).decode('utf-8')
return cls(name, payload)
return cls(payload, tag_name=name)

def __init__(self, tag_name, tag_value):
def __init__(self, tag_value, tag_name='None'):
self.tag_name = tag_name
self.tag_value = tag_value

Expand All @@ -98,6 +98,12 @@ def serialize(self, stream, include_name=True):
def clone(self):
return type(self)(self.tag_name, self.tag_value)

def __repr__(self):
return f'StringTag: {self.tag_name} = \'{self.tag_value}\''

def __eq__(self, other):
return self.tag_name == other.tag_name and self.tag_value == other.tag_value

register_parser(tag_id, DataNBTTag)

return DataNBTTag
Expand All @@ -112,12 +118,12 @@ class ArrayNBTTag:
@classmethod
def parse(cls, stream, name):
payload_length = int.from_bytes(stream.read(4), byteorder='big', signed=True)
tag = cls(name)
tag = cls(tag_name=name)
for i in range(payload_length):
tag.add_child(cls.clazz_sub_type.parse(stream, 'None'))
return tag

def __init__(self, tag_name, children=[]):
def __init__(self, tag_name='None', children=[]):
self.tag_name = tag_name
self.children = children[:]

Expand Down Expand Up @@ -147,6 +153,15 @@ def serialize(self, stream, include_name=True):
def clone(self):
return type(self)(self.tag_name, children=[c.clone() for c in self.children])

def __repr__(self):
str_dat = ', '.join([str(c.get()) for c in self.children])
return f'{type(self).clazz_name}: {self.tag_name} size {str(len(self.children))} = [{str_dat}]'

def __eq__(self, other):
return self.tag_name == other.tag_name and \
len(self.children) == len(other.children) and \
not any([not self.children[i] == other.children[i] for i in range(len(self.children))])

register_parser(tag_id, ArrayNBTTag)

return ArrayNBTTag
Expand All @@ -162,12 +177,12 @@ def parse(cls, stream, name):

sub_type = int.from_bytes(stream.read(1), byteorder='big', signed=False)
payload_length = int.from_bytes(stream.read(4), byteorder='big', signed=True)
tag = cls(name, sub_type)
tag = cls(sub_type, tag_name=name)
for i in range(payload_length):
tag.add_child(_parsers[sub_type].parse(stream, 'None'))
return tag

def __init__(self, tag_name, sub_type_id, children=[]):
def __init__(self, sub_type_id, tag_name='None', children=[]):
self.tag_name = tag_name
self.sub_type_id = sub_type_id
self.children = children[:]
Expand Down Expand Up @@ -197,6 +212,15 @@ def serialize(self, stream, include_name=True):
def clone(self):
return type(self)(self.tag_name, self.sub_type_id, children=[c.clone() for c in self.children])

def __repr__(self):
str_dat = ', '.join([c.__repr__() for c in self.children])
return f'ListTag: {self.tag_name} size {str(len(self.children))} = [{str_dat}]'

def __eq__(self, other):
return self.tag_name == other.tag_name and \
len(self.children) == len(other.children) and \
(len(self.children) == 0 or not any([not self.children[i] == other.children[i] for i in range(len(self.children))]))

register_parser(tag_id, ListNBTTag)

return ListNBTTag
Expand All @@ -208,13 +232,13 @@ class CompundNBTTag:

@classmethod
def parse(cls, stream, name):
tag = cls(name)
tag = cls(tag_name=name)
while stream.peek() != 0: # end tag
tag.add_child(parse_nbt(stream))
stream.read(1) # get rid of the end tag
return tag

def __init__(self, tag_name, children=[]):
def __init__(self, tag_name='None', children=[]):
self.tag_name = tag_name
self.children = { c.tag_name: c for c in children[:] }

Expand Down Expand Up @@ -254,6 +278,21 @@ def serialize(self, stream, include_name=True):
def clone(self):
return type(self)(self.tag_name, children=[v.clone() for k, v in self.children.items()])

def __repr__(self):
str_dat = ', '.join([c.__repr__() for name, c in self.children.items()])
return f'CompundTag: {self.tag_name} size {str(len(self.children))} = {{{str_dat}}}]'

def __eq__(self, other):
passed = True
for name, v in self.children.items():
if name not in other.children:
passed = False
elif other.children[name] != v:
passed = False
return self.tag_name == other.tag_name and \
len(self.children) == len(other.children) and \
passed

register_parser(tag_id, CompundNBTTag)

return CompundNBTTag
Expand Down
24 changes: 12 additions & 12 deletions pyanvil/world.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,37 +61,37 @@ def serialize(self):
if dirty:
self.palette = list(set([ b._state for b in self.blocks ] + [ BlockState('minecraft:air', {}) ]))
self.palette.sort(key=lambda s: s.name)
serial_section.add_child(nbt.ByteTag('Y', self.y_index))
serial_section.add_child(nbt.ByteTag(self.y_index, tag_name='Y'))
mat_id_mapping = {self.palette[i]: i for i in range(len(self.palette))}
new_palette = self._serialize_palette()
serial_section.add_child(new_palette)
serial_section.add_child(self._serialize_blockstates(mat_id_mapping))

if not serial_section.has('SkyLight'):
serial_section.add_child(nbt.ByteArrayTag('SkyLight', [nbt.ByteTag('None', -1) for i in range(2048)]))
serial_section.add_child(nbt.ByteArrayTag(tag_name='SkyLight', children=[nbt.ByteTag('None', -1) for i in range(2048)]))

if not serial_section.has('BlockLight'):
serial_section.add_child(nbt.ByteArrayTag('BlockLight', [nbt.ByteTag('None', -1) for i in range(2048)]))
serial_section.add_child(nbt.ByteArrayTag(tag_name='BlockLight', children=[nbt.ByteTag('None', -1) for i in range(2048)]))

return serial_section

def _serialize_palette(self):
serial_palette = nbt.ListTag('Palette', nbt.CompoundTag.clazz_id)
serial_palette = nbt.ListTag(nbt.CompoundTag.clazz_id, tag_name='Palette')
for state in self.palette:
palette_item = nbt.CompoundTag('None', children=[
nbt.StringTag('Name', state.name)
palette_item = nbt.CompoundTag(tag_name='None', children=[
nbt.StringTag(state.name, tag_name='Name')
])
if len(state.props) != 0:
serial_props = nbt.CompoundTag('Properties')
serial_props = nbt.CompoundTag(tag_name='Properties')
for name, val in state.props.items():
serial_props.add_child(nbt.StringTag(name, str(val)))
serial_props.add_child(nbt.StringTag(str(val), tag_name=name))
palette_item.add_child(serial_props)
serial_palette.add_child(palette_item)

return serial_palette

def _serialize_blockstates(self, state_mapping):
serial_states = nbt.LongArrayTag('BlockStates')
serial_states = nbt.LongArrayTag(tag_name='BlockStates')
width = math.ceil(math.log(len(self.palette), 2))
if width < 4:
width = 4
Expand All @@ -103,7 +103,7 @@ def _serialize_blockstates(self, state_mapping):
for i in range(int((len(self.blocks) * width)/64)):
lng = data & mask
lng = int.from_bytes(lng.to_bytes(8, byteorder='big', signed=False), byteorder='big', signed=True)
serial_states.add_child(nbt.LongTag('', lng))
serial_states.add_child(nbt.LongTag(lng))
data = data >> 64
return serial_states

Expand All @@ -125,7 +125,7 @@ def get_section(self, y):
if key not in self.sections:
self.sections[key] = ChunkSection(
[Block(BlockState('minecraft:air', {}), 0, 0, dirty=True) for i in range(4096)],
nbt.CompoundTag('None'),
nbt.CompoundTag(),
key
)
return self.sections[key]
Expand Down Expand Up @@ -180,7 +180,7 @@ def _divide_nibbles(arry):
return rtn

def pack(self):
new_sections = nbt.ListTag('Sections', nbt.CompoundTag.clazz_id, children=[
new_sections = nbt.ListTag(nbt.CompoundTag.clazz_id, tag_name='Sections', children=[
self.sections[sec].serialize() for sec in self.sections
])
new_nbt = self.raw_nbt.clone()
Expand Down
Loading

0 comments on commit 9a58b04

Please sign in to comment.