Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add to_image #5

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions geokube/core/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,12 @@ def to_geojson(self, target=None):
):
# HACK: The case `self.domain.type is None` is included to be able
# to handle undefined domain types temporarily.
result = {"data": []}
if self.time.size == 1:
result = {}
else:
raise NotImplementedError(
f"multiple times are not supported for geojson"
)
cube = (
self
if isinstance(self.domain.crs, GeogCS)
Expand Down Expand Up @@ -436,7 +441,7 @@ def to_geojson(self, target=None):
else:
feature["properties"][field.name] = value_
time_data["features"].append(feature)
result["data"].append(time_data)
result = time_data
else:
raise NotImplementedError(
f"'self.domain.type' is {self.domain.type}, which is currently"
Expand All @@ -449,6 +454,22 @@ def to_geojson(self, target=None):

return result

def to_image(
self,
filepath,
width,
height,
dpi=100,
format='png',
transparent=True,
bgcolor='FFFFFF'
):
# TODO: support multiple fields
if len(self.fields) > 1:
raise ValueError("to_image support only 1 field")
else:
next(iter(self.fields.values())).to_image(filepath, width, height, dpi, format, transparent, bgcolor)

@classmethod
@geokube_logging
def from_xarray(
Expand Down
79 changes: 77 additions & 2 deletions geokube/core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import cartopy.crs as ccrs
import cartopy.feature as cartf
import dask.array as da
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
Expand Down Expand Up @@ -1192,6 +1193,9 @@ def plot(
figsize=None,
robust=None,
aspect=None,
save_path=None,
save_kwargs=None,
clean_image=False,
**kwargs,
):
axis_names = self.domain._axis_to_name
Expand All @@ -1200,6 +1204,10 @@ def plot(
lat = self.coords.get(axis_names.get(AxisType.LATITUDE))
lon = self.coords.get(axis_names.get(AxisType.LONGITUDE))

# NOTE: The argument `save_kwargs` passes the keyword arguments to the
# `savefig` method, see:
# https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html

# Resolving time series and layers because they do not require most of
# processing other plot types do:
if self._domain._type is DomainType.POINTS:
Expand Down Expand Up @@ -1243,6 +1251,13 @@ def plot(
if "row" not in kwargs and "col" not in kwargs:
for line in plot:
line.axes.set_title("Point Time Series")
if clean_image:
plot.axes.set_axis_off()
plot.axes.set_title('')
if save_path:
fig = plot[0].figure
fig.tight_layout()
fig.savefig(save_path, **(save_kwargs or {}))
return plot
if aspect == "profile":
data = self.to_xarray(encoding=False)[self.name]
Expand All @@ -1263,6 +1278,13 @@ def plot(
if "row" not in kwargs and "col" not in kwargs:
for line in plot:
line.axes.set_title("Point Layers")
if clean_image:
plot.axes.set_axis_off()
plot.axes.set_title('')
if save_path:
fig = plot[0].figure
fig.tight_layout()
fig.savefig(save_path, **(save_kwargs or {}))
return plot

# Resolving Cartopy features and gridlines:
Expand Down Expand Up @@ -1442,8 +1464,56 @@ def plot(
ax.set_xticks(x_ticks)
ax.set_yticks(y_ticks)

if clean_image:
if isinstance(plot.axes, np.ndarray):
for axis in plot.axes.flat:
axis.set_axis_off()
axis.set_title('')
else:
plot.axes.set_axis_off()
plot.axes.set_title('')
if save_path:
if hasattr(plot, 'fig'):
fig = plot.fig
elif hasattr(plot, 'figure'):
fig = plot.figure
else:
raise NotImplementedError()
fig.tight_layout()
# fig.tight_layout(pad=0, h_pad=0, w_pad=0)
fig.savefig(save_path, **(save_kwargs or {}))

return plot

def to_image(
self,
filepath,
width,
height,
dpi=100,
format='png',
transparent=True,
bgcolor='FFFFFF'
):
# NOTE: This method assumes default DPI value.
f = self
if self.domain.crs != GeogCS(6371229):
f = self.to_regular()
# dpi = plt.rcParams['figure.dpi']
w, h = width / dpi, height / dpi
f.plot(
figsize=(w, h),
add_colorbar=False,
save_path=filepath,
save_kwargs={
'transparent': transparent,
'pad_inches': 0,
'dpi': dpi,
# 'bbox_inches': [[0, 0], [w, h]]
},
clean_image=True
)

def to_geojson(self, target=None):
self.load()
if self.domain.type is DomainType.POINTS:
Expand Down Expand Up @@ -1471,7 +1541,12 @@ def to_geojson(self, target=None):
):
# HACK: The case `self.domain.type is None` is included to be able
# to handle undefined domain types temporarily.
result = {"data": []}
if self.time.size == 1:
result = {}
else:
raise NotImplementedError(
f"multiple times are not supported for geojson"
)
field = (
self
if isinstance(self.domain.crs, GeogCS)
Expand Down Expand Up @@ -1533,7 +1608,7 @@ def to_geojson(self, target=None):
"properties": {self.name: float(value)},
}
time_data["features"].append(feature)
result["data"].append(time_data)
result = time_data
else:
raise NotImplementedError(
f"'self.domain.type' is {self.domain.type}, which is currently"
Expand Down
Loading