117 lines
3.5 KiB
Python
117 lines
3.5 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Optional, Tuple
|
|
|
|
import numpy as np
|
|
import rasterio
|
|
from rasterio.errors import RasterioIOError
|
|
from rasterio.transform import rowcol, xy
|
|
from rasterio.windows import Window
|
|
|
|
# Module-level logger; open_georaster uses it to report which raster source it
# opened and why the MAP file could not be opened.
LOG = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class RasterHandle:
    """An open rasterio dataset together with the path it was opened from.

    The underlying dataset holds an open file, so either use the handle as a
    context manager (``with open_georaster(...) as handle: ...``) or call
    :meth:`close` explicitly when done.
    """

    # Path that was passed to ``rasterio.open``.
    path: Path
    # The open dataset; the properties below delegate to it.
    dataset: rasterio.io.DatasetReader
    # True when the dataset came from the ``.map`` file, False for the TIFF fallback.
    used_map_file: bool

    @property
    def width(self) -> int:
        """Raster width in pixels (number of columns)."""
        return self.dataset.width

    @property
    def height(self) -> int:
        """Raster height in pixels (number of rows)."""
        return self.dataset.height

    @property
    def crs(self):
        """Coordinate reference system of the dataset (may be None if the file has none)."""
        return self.dataset.crs

    @property
    def transform(self):
        """Affine geotransform mapping pixel (col, row) to dataset CRS coordinates."""
        return self.dataset.transform

    def close(self) -> None:
        """Close the underlying dataset."""
        self.dataset.close()

    def __enter__(self) -> RasterHandle:
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the dataset; exceptions from the with-body propagate.
        self.close()
|
|
|
|
|
|
def open_georaster(map_path: Optional[str | Path] = None, tif_path: Optional[str | Path] = None) -> RasterHandle:
    """Open the ``.map`` file if possible, otherwise fall back to the ``.tif``.

    The ``.map`` file is tried first because GDAL's MAP driver supplies
    georeferencing. Any failure opening it is logged as a warning and the TIFF is
    tried next.

    Args:
        map_path: Optional path to the ``.map`` file. Falsy values are skipped.
        tif_path: Optional path to the TIFF image. Falsy values are skipped.

    Returns:
        A :class:`RasterHandle`; ``used_map_file`` records which source was opened.

    Raises:
        RuntimeError: If neither path was given, or every given path failed to
            open. In the latter case the last underlying error is chained as
            ``__cause__``.
    """
    last_err: Optional[Exception] = None

    if map_path:
        try:
            p = Path(map_path)
            ds = rasterio.open(p)
        except Exception as e:  # noqa: BLE001 - any MAP failure should fall back to the TIFF
            last_err = e
            LOG.warning("Could not open MAP dataset %s: %s", map_path, e)
        else:
            LOG.info("Opened georeferenced MAP dataset: %s", p)
            return RasterHandle(p, ds, used_map_file=True)

    if tif_path:
        p = Path(tif_path)
        try:
            ds = rasterio.open(p)
        except RasterioIOError as e:
            last_err = e
            LOG.warning("Could not open TIFF dataset %s: %s", tif_path, e)
        else:
            LOG.info("Opened TIFF dataset: %s", p)
            return RasterHandle(p, ds, used_map_file=False)

    if last_err is None:
        # Only reachable when both paths were falsy: nothing was attempted.
        raise RuntimeError("Could not open raster: neither map_path nor tif_path was given")
    raise RuntimeError(f"Could not open raster. Last error: {last_err}") from last_err
|
|
|
|
|
|
def read_window_rgb(ds: rasterio.io.DatasetReader, window: Window) -> np.ndarray:
    """Read a raster window as a ``uint8`` RGB array of shape ``(H, W, 3)``.

    The read is boundless with ``fill_value=255``, so parts of the window outside
    the raster extent come back filled with 255.

    Band handling:
      * 3 or more bands: the first three are used as R, G, B; extra bands are dropped.
      * 1 or 2 bands: band 0 is treated as grey and copied into all three channels.
        A second band is ignored (presumably alpha — confirm for your data).

    Non-``uint8`` data is clipped to ``[0, 255]`` and cast without rescaling, so
    e.g. 16-bit imagery saturates rather than being scaled down.

    Raises:
        ValueError: If the read array is not band-first 3-D or contains no bands.
    """
    arr = ds.read(window=window, boundless=True, fill_value=255)
    if arr.ndim != 3:
        raise ValueError(f"Expected band-first array, got shape={arr.shape}")

    n_bands = arr.shape[0]
    if n_bands >= 3:
        arr = arr[:3]
    elif n_bands >= 1:
        # Replicate band 0 so the output is always 3-channel, also for 2-band input.
        arr = np.repeat(arr[:1], 3, axis=0)
    else:
        raise ValueError(f"Expected at least one band, got shape={arr.shape}")

    arr = np.moveaxis(arr, 0, -1)
    if arr.dtype != np.uint8:
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    return arr
|
|
|
|
|
|
def iter_windows(width: int, height: int, tile_size: int, overlap: int):
    """Yield ``Window`` tiles covering a ``width`` x ``height`` raster, row by row.

    Tiles start every ``tile_size - overlap`` pixels (at least 1) along each axis.
    The last tile on each axis is clipped to the raster edge, so it may be smaller
    than ``tile_size``; iteration along an axis stops once a tile reaches the edge.
    A zero ``width`` or ``height`` yields nothing.

    Args:
        width: Raster width in pixels (columns).
        height: Raster height in pixels (rows).
        tile_size: Edge length of a full tile in pixels; must be positive.
        overlap: Pixels shared by neighbouring tiles. Values >= ``tile_size`` give
            a stride of 1. Negative values leave gaps between tiles.

    Yields:
        ``Window(col_off, row_off, width, height)`` for each tile.

    Raises:
        ValueError: If ``tile_size`` is not positive (raised on first iteration).
    """
    if tile_size <= 0:
        # A non-positive tile size would only produce zero-size windows.
        raise ValueError(f"tile_size must be positive, got {tile_size}")

    step = max(1, tile_size - overlap)
    row_off = 0
    while row_off < height:
        tile_h = min(tile_size, height - row_off)
        col_off = 0
        while col_off < width:
            yield Window(col_off, row_off, min(tile_size, width - col_off), tile_h)
            if col_off + tile_size >= width:
                break
            col_off += step
        if row_off + tile_size >= height:
            break
        row_off += step
|
|
|
|
|
|
def lonlat_to_pixel(ds: rasterio.io.DatasetReader, lon: float, lat: float) -> Tuple[int, int]:
    """Return the ``(row, col)`` of the pixel containing the point ``(lon, lat)``.

    ``lon`` and ``lat`` are passed straight to ``rowcol`` as x/y together with
    ``ds.transform``. No reprojection happens here, so they must already be in the
    raster's CRS. They are only literal longitude/latitude when that CRS is
    geographic.

    For a production version, reproject coordinates into ds.crs with pyproj first.
    The PoC does that in coordinates.py.
    """
    affine = ds.transform
    pixel_row, pixel_col = rowcol(affine, lon, lat)
    return int(pixel_row), int(pixel_col)
|
|
|
|
|
|
def pixel_to_lonlat(ds: rasterio.io.DatasetReader, row: int, col: int) -> Tuple[float, float]:
    """Return the ``(x, y)`` coordinates of the centre of pixel ``(row, col)``.

    Coordinates come from ``ds.transform``, so they are in the dataset's own CRS.
    They are longitude/latitude only when that CRS is geographic; no reprojection
    is performed.
    """
    center = xy(ds.transform, row, col, offset="center")
    lon, lat = center
    return (float(lon), float(lat))
|
|
|
|
|
|
def has_real_georef(ds: rasterio.io.DatasetReader) -> bool:
    """Return True if ``ds`` has a CRS and a non-identity geotransform.

    An identity transform is typically what is reported when a file carries no
    geotransform, so it is treated as "not georeferenced". Any exception raised
    while inspecting the dataset also yields False (best-effort check).
    """
    try:
        if ds.crs is None:
            return False
        return not ds.transform.is_identity
    except Exception:
        return False
|