Skip to content

Commit

Permalink
Merge pull request #7 from CMCC-Foundation/add_method_to_image
Browse files Browse the repository at this point in the history
Add method to image
  • Loading branch information
Marco Mancini authored Jan 30, 2024
2 parents 458664b + 5dd2f2d commit 0526755
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 4 deletions.
25 changes: 23 additions & 2 deletions geokube/core/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,12 @@ def to_geojson(self, target=None):
):
# HACK: The case `self.domain.type is None` is included to be able
# to handle undefined domain types temporarily.
result = {"data": []}
if self.time.size == 1:
result = {}
else:
raise NotImplementedError(
f"multiple times are not supported for geojson"
)
cube = (
self
if isinstance(self.domain.crs, GeogCS)
Expand Down Expand Up @@ -436,7 +441,7 @@ def to_geojson(self, target=None):
else:
feature["properties"][field.name] = value_
time_data["features"].append(feature)
result["data"].append(time_data)
result = time_data
else:
raise NotImplementedError(
f"'self.domain.type' is {self.domain.type}, which is currently"
Expand All @@ -449,6 +454,22 @@ def to_geojson(self, target=None):

return result

def to_image(
self,
filepath,
width,
height,
dpi=100,
format='png',
transparent=True,
bgcolor='FFFFFF'
):
# TODO: support multiple fields
if len(self.fields) > 1:
raise ValueError("to_image support only 1 field")
else:
next(iter(self.fields.values())).to_image(filepath, width, height, dpi, format, transparent, bgcolor)

@classmethod
@geokube_logging
def from_xarray(
Expand Down
79 changes: 77 additions & 2 deletions geokube/core/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import cartopy.crs as ccrs
import cartopy.feature as cartf
import dask.array as da
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
Expand Down Expand Up @@ -1192,6 +1193,9 @@ def plot(
figsize=None,
robust=None,
aspect=None,
save_path=None,
save_kwargs=None,
clean_image=False,
**kwargs,
):
axis_names = self.domain._axis_to_name
Expand All @@ -1200,6 +1204,10 @@ def plot(
lat = self.coords.get(axis_names.get(AxisType.LATITUDE))
lon = self.coords.get(axis_names.get(AxisType.LONGITUDE))

# NOTE: The argument `save_kwargs` passes the keyword arguments to the
# `savefig` method, see:
# https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.savefig.html

# Resolving time series and layers because they do not require most of
# processing other plot types do:
if self._domain._type is DomainType.POINTS:
Expand Down Expand Up @@ -1243,6 +1251,13 @@ def plot(
if "row" not in kwargs and "col" not in kwargs:
for line in plot:
line.axes.set_title("Point Time Series")
if clean_image:
plot.axes.set_axis_off()
plot.axes.set_title('')
if save_path:
fig = plot[0].figure
fig.tight_layout()
fig.savefig(save_path, **(save_kwargs or {}))
return plot
if aspect == "profile":
data = self.to_xarray(encoding=False)[self.name]
Expand All @@ -1263,6 +1278,13 @@ def plot(
if "row" not in kwargs and "col" not in kwargs:
for line in plot:
line.axes.set_title("Point Layers")
if clean_image:
plot.axes.set_axis_off()
plot.axes.set_title('')
if save_path:
fig = plot[0].figure
fig.tight_layout()
fig.savefig(save_path, **(save_kwargs or {}))
return plot

# Resolving Cartopy features and gridlines:
Expand Down Expand Up @@ -1442,8 +1464,56 @@ def plot(
ax.set_xticks(x_ticks)
ax.set_yticks(y_ticks)

if clean_image:
if isinstance(plot.axes, np.ndarray):
for axis in plot.axes.flat:
axis.set_axis_off()
axis.set_title('')
else:
plot.axes.set_axis_off()
plot.axes.set_title('')
if save_path:
if hasattr(plot, 'fig'):
fig = plot.fig
elif hasattr(plot, 'figure'):
fig = plot.figure
else:
raise NotImplementedError()
fig.tight_layout()
# fig.tight_layout(pad=0, h_pad=0, w_pad=0)
fig.savefig(save_path, **(save_kwargs or {}))

return plot

def to_image(
self,
filepath,
width,
height,
dpi=100,
format='png',
transparent=True,
bgcolor='FFFFFF'
):
# NOTE: This method assumes default DPI value.
f = self
if self.domain.crs != GeogCS(6371229):
f = self.to_regular()
# dpi = plt.rcParams['figure.dpi']
w, h = width / dpi, height / dpi
f.plot(
figsize=(w, h),
add_colorbar=False,
save_path=filepath,
save_kwargs={
'transparent': transparent,
'pad_inches': 0,
'dpi': dpi,
# 'bbox_inches': [[0, 0], [w, h]]
},
clean_image=True
)

def to_geojson(self, target=None):
self.load()
if self.domain.type is DomainType.POINTS:
Expand Down Expand Up @@ -1471,7 +1541,12 @@ def to_geojson(self, target=None):
):
# HACK: The case `self.domain.type is None` is included to be able
# to handle undefined domain types temporarily.
result = {"data": []}
if self.time.size == 1:
result = {}
else:
raise NotImplementedError(
f"multiple times are not supported for geojson"
)
field = (
self
if isinstance(self.domain.crs, GeogCS)
Expand Down Expand Up @@ -1533,7 +1608,7 @@ def to_geojson(self, target=None):
"properties": {self.name: float(value)},
}
time_data["features"].append(feature)
result["data"].append(time_data)
result = time_data
else:
raise NotImplementedError(
f"'self.domain.type' is {self.domain.type}, which is currently"
Expand Down

0 comments on commit 0526755

Please sign in to comment.