@numba.stencil(cval=0.0)
def grid_cell_wall_area(
    building_height=np.array([[]]), dx: float = 1.0, dy: float = 1.0
) -> numba.float64[:, :]:
    """Stencil kernel: vertical wall area attributed to one grid cell.

    For the current cell and each of its four edge neighbours, the height
    difference ``dz = h[centre] - h[neighbour]`` is computed. A wall is added
    only when ``dz < 0`` (the neighbour is taller), with area ``-dz`` times
    the length of the shared cell edge. Each wall between two cells is
    therefore counted exactly once, at the lower of the two cells. Equal
    heights contribute nothing.

    Neighbours along axis 0 share an edge of length ``dx``, and neighbours
    along axis 1 share an edge of length ``dy``. This assumes the array is
    laid out as (y, x); TODO confirm against the caller's data.

    NOTE(review): numba stencils do not read out-of-bounds neighbours.
    Every output cell whose +/-1 neighbourhood would leave the array (the
    outermost ring) is set to ``cval`` (0.0) without evaluating this kernel.
    Walls attributed to those edge cells are therefore dropped unless the
    caller pads the input.
    """
    # dz < 0 means the neighbour is taller than the centre cell.
    # Axis-0 neighbours are labelled "bottom"/"top"; axis-1 neighbours
    # are labelled "left"/"right".
    dz_bottom = building_height[0, 0] - building_height[-1, 0]
    dz_top = building_height[0, 0] - building_height[1, 0]
    dz_left = building_height[0, 0] - building_height[0, -1]
    dz_right = building_height[0, 0] - building_height[0, 1]

    grid_wall_area = 0.0

    # Counting only negative differences (taller neighbour) attributes each
    # wall to the lower cell, so shared walls are not double-counted.
    if dz_top < 0.0:
        grid_wall_area -= dz_top * dx  # Top wall (axis-0 neighbour, edge length dx)
    if dz_bottom < 0.0:
        grid_wall_area -= dz_bottom * dx  # Bottom wall (axis-0 neighbour, edge length dx)
    if dz_left < 0.0:
        grid_wall_area -= dz_left * dy  # Left wall (axis-1 neighbour, edge length dy)
    if dz_right < 0.0:
        grid_wall_area -= dz_right * dy  # Right wall (axis-1 neighbour, edge length dy)

    return grid_wall_area


@numba.njit(parallel=True, looplift=True)
def compute_wall_area(
    building_height=np.array([[]]),
    dx: float = 1.0,
    dy: float = 1.0,
    nx: int = 0,
    ny: int = 0,
) -> float:
    """Return the total vertical wall area of a 2-D building-height field.

    Each wall between two neighbouring cells is counted once, with area
    (height difference) x (shared edge length). Walls facing outward across
    the array boundary are not counted; they are treated as facing
    zero-height ground.

    Parameters
    ----------
    building_height
        2-D array of building heights (0 where there is no building).
    dx, dy
        Grid spacing: ``dx`` is the edge length shared by axis-0 neighbours
        and ``dy`` is the edge length shared by axis-1 neighbours.
    nx, ny
        Unused. They are kept for backward compatibility; the array shape
        is used instead.
    """
    n0, n1 = building_height.shape
    # numba stencils never evaluate output cells whose neighbourhood would
    # leave the array; those cells are set to ``cval`` (0.0) instead.
    # Without padding, walls between an edge cell and its taller interior
    # neighbour would be silently lost. A one-cell ring of zero height makes
    # every real cell interior. The padding cells themselves still get
    # cval, so outward-facing boundary walls remain uncounted, as before.
    padded = np.zeros((n0 + 2, n1 + 2))
    padded[1 : n0 + 1, 1 : n1 + 1] = building_height
    wall_area = np.sum(grid_cell_wall_area(padded, dx, dy))

    return wall_area


def urban_parameters_from_morphology(
    lcz_map: xr.DataArray,
    usm_driver: xr.Dataset,
) -> dict:
    """Derive per-zone urban morphology parameters from a driver dataset.

    Parameters
    ----------
    lcz_map
        Zone label for every grid cell. It is assumed to share the
        horizontal grid (and dimension names) of ``usm_driver``; TODO
        confirm, since mismatched dimensions would broadcast the masks.
    usm_driver
        Dataset with coordinates ``x`` and ``y`` (uniform spacing, at least
        two points each) and the variables ``building_type``,
        ``street_type`` and ``buildings_2d``.

    Returns
    -------
    dict
        Maps each zone label found in ``lcz_map`` to a dict with the float
        entries ``building_height`` (mean over building cells, NaN if none),
        ``urban_fraction``, ``building_plan_area_fraction``,
        ``street_canyon_aspect_ratio`` and
        ``building_frontal_area_fraction``. Labels that select no cells
        (e.g. NaN, because NaN != NaN) are skipped.
    """
    zones = np.unique(lcz_map)
    params = {}

    dx_usm = float(usm_driver.x[1] - usm_driver.x[0])
    dy_usm = float(usm_driver.y[1] - usm_driver.y[0])
    nx_usm = usm_driver.x.size
    ny_usm = usm_driver.y.size

    # Zone-independent masks and fields are hoisted out of the loop.
    is_building = usm_driver["building_type"] > 0
    is_street = usm_driver["street_type"] > 0
    buildings_2d = usm_driver["buildings_2d"]

    for zone in zones:
        lcz_mask = lcz_map == zone
        n_total = int(lcz_mask.sum())
        if n_total == 0:
            # Nothing to average over; avoids 0/0 for labels such as NaN.
            continue

        building_mask = lcz_mask & is_building
        street_mask = lcz_mask & is_street
        n_building = int(building_mask.sum())
        n_urban = int((building_mask | street_mask).sum())

        total_area = n_total * dx_usm * dy_usm
        plan_area_fraction = n_building / n_total

        # Heights inside this zone only. Everything else is 0, so walls on
        # the zone boundary are attributed to the neighbouring non-zone cell.
        # ``.values`` (unlike ``.data``) always yields a NumPy array, which
        # the numba-compiled function requires.
        zone_heights = buildings_2d.where(building_mask).fillna(0.0)
        wall_area = compute_wall_area(
            building_height=zone_heights.values,
            dx=dx_usm,
            dy=dy_usm,
            nx=nx_usm,
            ny=ny_usm,
        )
        wall_area_density = float(wall_area) / total_area

        # Canyon aspect ratio H/W ~ wall density / (2 * open fraction). A zone
        # fully covered by buildings has no open ground; report infinity
        # explicitly instead of dividing by zero.
        open_fraction = 1.0 - plan_area_fraction
        if open_fraction > 0.0:
            aspect_ratio = 0.5 * wall_area_density / open_fraction
        else:
            aspect_ratio = np.inf

        params[zone] = {
            # Mean building height over building cells of this zone.
            "building_height": float(buildings_2d.where(building_mask).mean()),
            # Fraction of zone cells that are buildings or streets.
            "urban_fraction": n_urban / n_total,
            "building_plan_area_fraction": plan_area_fraction,
            "street_canyon_aspect_ratio": aspect_ratio,
            # Walls assumed evenly split over four orientations.
            "building_frontal_area_fraction": wall_area_density / 4,
        }

    return params