Skip to content
Snippets Groups Projects
points2trees.py 20.65 KiB
import os
import multiprocessing
import argparse

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import scipy

from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import DBSCAN
from scipy.spatial import ConvexHull
import networkx as nx

from fsct.tools import *
from fsct.fit_cylinders import RANSAC_helper

import warnings
warnings.filterwarnings('ignore')
pd.options.mode.chained_assignment = None

def generate_path(samples, origins, n_neighbours=200, max_length=0):
    """Link clusters into a graph and trace every cluster back to its closest origin.

    Each row of ``samples`` is a point tagged with the cluster it belongs to
    (``clstr``). Points are joined to their nearest neighbours; the shortest
    point-to-point gap between two different clusters becomes the weight of the
    edge between them. Edges longer than ``max_length`` are discarded and a
    multi-source Dijkstra search from ``origins`` assigns every reachable
    cluster to the origin with the smallest cumulative distance.

    Parameters
    ----------
    samples : pandas.DataFrame
        Points with ``x``, ``y``, ``z`` and ``clstr`` columns.
    origins : iterable
        Cluster ids used as path sources. Origins that are not the source of
        any retained edge are dropped.
    n_neighbours : int
        Nearest neighbours queried per point; clamped to ``len(samples) - 1``
        because a point is never returned as its own neighbour.
    max_length : float
        Longest permitted edge; larger gaps are treated as breaks.

    Returns
    -------
    pandas.DataFrame
        Columns ``clstr`` (reached cluster), ``distance`` (cumulative edge
        length from the origin), ``t_clstr`` (origin the path starts from) and
        ``is_tip`` (True when the cluster lies on no path but its own).
        The frame is empty when no origin is connected to the graph.
    """
    from collections import Counter

    empty = pd.DataFrame({'clstr': pd.Series(dtype=float),
                          'distance': pd.Series(dtype=float),
                          't_clstr': pd.Series(dtype=float),
                          'is_tip': pd.Series(dtype=bool)})

    # kneighbors() without X excludes each point itself, so at most n-1
    # neighbours exist; asking for more makes sklearn raise
    k = min(n_neighbours, len(samples) - 1)
    if k < 1:
        return empty

    # nearest neighbours for every point (e.g. convex-hull vertices of clusters)
    nn = NearestNeighbors(n_neighbors=k).fit(samples[['x', 'y', 'z']])
    distances, indices = nn.kneighbors()
    clstr = samples.clstr.values
    from_to_all = pd.DataFrame(np.vstack([np.repeat(clstr, k),
                                          clstr[indices.ravel()],
                                          distances.ravel()]).T,
                               columns=['source', 'target', 'length'])

    # remove X-X connections (neighbours inside the same cluster)
    from_to_all = from_to_all.loc[from_to_all.target != from_to_all.source]

    # keep only the shortest gap between each ordered pair of clusters
    edges = from_to_all.groupby(['source', 'target']).length.min().reset_index()
    # remove edges that are likely leaps between trees
    edges = edges.loc[edges.length <= max_length]

    # drop isolated origins (no retained edge starts at them); set lookup
    # keeps this linear in the number of origins
    edge_sources = set(edges.source.values)
    origins = [s for s in origins if s in edge_sources]
    if not origins:
        # networkx raises on an empty source list
        return empty

    G = nx.from_pandas_edgelist(edges, edge_attr=['length'])
    distance, shortest_path = nx.multi_source_dijkstra(G,
                                                       sources=origins,
                                                       weight='length')

    # every reached node has a path whose first element is its origin
    paths = pd.DataFrame({'clstr': list(distance.keys()),
                          'distance': list(distance.values()),
                          't_clstr': [shortest_path[node][0] for node in distance]})

    # a node that occurs in exactly one path (its own) is a branch tip
    occurrences = Counter(node for path in shortest_path.values() for node in path)
    tips = [node for node, count in occurrences.items() if count == 1]
    paths['is_tip'] = paths.clstr.isin(tips)

    return paths

def cube(pc):
    """Reduce one cluster of points to the vertices of its 3D convex hull.

    Parameters
    ----------
    pc : pandas.DataFrame
        Points of a single cluster with ``x``, ``y`` and ``z`` columns.

    Returns
    -------
    pandas.DataFrame
        The rows of ``pc`` that are hull vertices, in random order. ``pc``
        itself is returned unchanged when it has 5 or fewer rows, or when
        Qhull cannot build a hull (e.g. all points are coplanar).
    """
    if len(pc) <= 5:
        return pc

    # QhullError is public on scipy.spatial in current SciPy; older releases
    # only expose it through the deprecated scipy.spatial.qhull submodule.
    qhull_error = getattr(scipy.spatial, 'QhullError', None)
    if qhull_error is None:
        qhull_error = scipy.spatial.qhull.QhullError

    try:
        vertices = ConvexHull(pc[['x', 'y', 'z']]).vertices
    except qhull_error as e:
        print(f"Error computing convex hull for group with {len(pc)} points: {e}")
        # degenerate cluster: keep all of its points rather than dropping it
        return pc

    # drawing every vertex without replacement is a random permutation
    idx = np.random.choice(vertices, size=len(vertices), replace=False)
    return pc.loc[pc.index[idx]]

if __name__ == '__main__':

    # glob resolves neighbouring tile file names; imported here instead of
    # relying on it leaking through the fsct.tools star-import
    import glob

    parser = argparse.ArgumentParser()
    parser.add_argument('--tile', '-t', type=str, default='', required=True, help='fsct directory')
    parser.add_argument('--odir', '-o', type=str, required=True, help='output directory')
    parser.add_argument('--tindex', type=str, required=True, help='path to tile index')
    parser.add_argument('--n-tiles', default=3, type=int, help='enlarges the number of tiles i.e. 3x3 or tiles or 5 x 5 tiles')
    parser.add_argument('--overlap', default=False, type=float, help='buffer to crop adjacent tiles')
    parser.add_argument('--slice-thickness', default=.2, type=float, help='slice thickness for constructing graph')
    parser.add_argument('--find-stems-height', default=1.5, type=float, help='height for identifying stems')    
    parser.add_argument('--find-stems-thickness', default=.5, type=float, help='thickness of slice used for identifying stems')
    parser.add_argument('--find-stems-min-radius', default=.025, type=float, help='minimum radius of found stems')
    parser.add_argument('--find-stems-min-points', default=200, type=int, help='minimum number of points for found stems')
    parser.add_argument('--graph-edge-length', default=1, type=float, help='maximum distance used to connect points in graph')
    parser.add_argument('--graph-maximum-cumulative-gap', default=np.inf, type=float, 
                        help='maximum cumulative distance between a base and a cluster')
    parser.add_argument('--min-points-per-tree', default=0, type=int, help='minimum number of points for a identified tree')
    parser.add_argument('--add-leaves', action='store_true', help='add leaf points')
    parser.add_argument('--add-leaves-voxel-length', default=.5, type=float, help='voxel sixe when add leaves')
    parser.add_argument('--add-leaves-edge-length', default=1, type=float, 
                        help='maximum distance used to connect points in leaf graph')
    parser.add_argument('--save-diameter-class', action='store_true', help='save into dimater class directories')
    parser.add_argument('--ignore-missing-tiles', action='store_true', help='ignore missing neighbouring tiles')
    parser.add_argument('--pandarallel', action='store_true', help='use pandarallel')
    parser.add_argument('--verbose', action='store_true', help='print something')
    params = parser.parse_args()

    if params.pandarallel:
        try:
            from pandarallel import pandarallel
            pandarallel.initialize(progress_bar=True if params.verbose else False)
        except ImportError:
            print('--- pandarallel not installed ---')
            params.pandarallel = False
        except Exception as err:
            # best effort: any other start-up failure falls back to serial apply
            print(f'--- pandarallel could not be initialised ({err}) ---')
            params.pandarallel = False

    if params.verbose:
        print('---- parameters ----')
        for k, v in params.__dict__.items():
            print(f'{k:<35}{v}')

    # sentinel for "no base"; generate_path in this file reads it via the global params
    params.not_base = -1
    xyz = ['x', 'y', 'z']  # shorthand

    # tile files are named "<tile number>.<...>.ply"
    params.dir, params.fn = os.path.split(params.tile)
    params.n = int(params.fn.split('.')[0])

    params.pc = ply_io.read_ply(params.tile)
    params.pc.loc[:, 'buffer'] = False
    params.pc.loc[:, 'fn'] = params.n

    # extent of the tile being processed (buffer tiles are cropped to it)
    bbox = {}
    bbox['xmin'], bbox['xmax'] = params.pc.x.min(), params.pc.x.max()
    bbox['ymin'], bbox['ymax'] = params.pc.y.min(), params.pc.y.max()
    bbox = dict2class(bbox)

    # --- read in neighbouring tiles ---
    params.ti = pd.read_csv(params.tindex,
                            sep=' ',
                            names=['tile', 'x', 'y'])
    n_tiles = NearestNeighbors(n_neighbors=len(params.ti)).fit(params.ti[['x', 'y']])
    distance, indices = n_tiles.kneighbors(params.ti.loc[params.ti.tile == params.n][['x', 'y']])
    # indices are positional; element 0 is the tile itself
    # todo: this could be made smarter e.g. using distance
    buffer_tiles = params.ti.iloc[indices[0][1:params.n_tiles**2]]['tile'].values

    tiles = [params.pc]
    for t in tqdm(buffer_tiles,
                  total=len(buffer_tiles),
                  desc='read in neighbouring tiles',
                  disable=not params.verbose):

        pattern = os.path.join(params.dir, f'{t:03}*.ply')
        matches = glob.glob(pattern)
        if not matches:
            if params.ignore_missing_tiles:
                print(f'tile {pattern} not available')
                continue
            raise FileNotFoundError(f'tile {pattern} not available')

        tmp = ply_io.read_ply(matches[0])
        if params.overlap:
            tmp = tmp.loc[(tmp.x.between(bbox.xmin - params.overlap, bbox.xmax + params.overlap)) &
                          (tmp.y.between(bbox.ymin - params.overlap, bbox.ymax + params.overlap))]
        if len(tmp) == 0:
            continue
        tmp.loc[:, 'buffer'] = True
        tmp.loc[:, 'fn'] = t
        tiles.append(tmp)

    # DataFrame.append was removed in pandas 2.0; concatenate once instead
    if len(tiles) > 1:
        params.pc = pd.concat(tiles, ignore_index=True)

    # --- this can be dropped soon ---
    if 'nz' in params.pc.columns: params.pc.rename(columns={'nz': 'n_z'}, inplace=True)

    # save space
    params.pc = params.pc[['x', 'y', 'z', 'n_z', 'label', 'buffer', 'fn']].copy()
    params.pc[['x', 'y', 'z', 'n_z']] = params.pc[['x', 'y', 'z', 'n_z']].astype(np.float32)
    params.pc[['label', 'fn']] = params.pc[['label', 'fn']].astype(np.int16)

    ### generate skeleton points
    if params.verbose: print('\n----- skeletonisation started -----')

    # extract stem points (label 3)
    stem_pc = params.pc.loc[params.pc.label == 3].copy()

    # slice stem_pc by absolute height (slice) and normalised height (n_slice)
    stem_pc.loc[:, 'slice'] = (stem_pc.z // params.slice_thickness).astype(int) * params.slice_thickness
    stem_pc.loc[:, 'n_slice'] = (stem_pc.n_z // params.slice_thickness).astype(int)

    # cluster within height slices; -1 marks unclustered points
    stem_pc.loc[:, 'clstr'] = -1
    label_offset = 0

    for slice_height in tqdm(np.sort(stem_pc.n_slice.unique()),
                             disable=not params.verbose,
                             desc='slice data vertically and clustering'):

        new_slice = stem_pc.loc[stem_pc.n_slice == slice_height].copy()

        if len(new_slice) > 200:
            dbscan = DBSCAN(eps=.1, min_samples=20).fit(new_slice[xyz])
            new_slice.loc[:, 'clstr'] = dbscan.labels_
            # offset labels so cluster ids are unique across slices
            new_slice.loc[new_slice.clstr > -1, 'clstr'] += label_offset
            stem_pc.loc[new_slice.index, 'clstr'] = new_slice.clstr
            label_offset = stem_pc.clstr.max() + 1

    # group skeleton points
    grouped = stem_pc.loc[stem_pc.clstr != -1].groupby('clstr')

    if params.verbose: print('fitting convex hulls to clusters')
    if params.pandarallel:
        chull = grouped.parallel_apply(cube)  # parallel_apply only works with pd < 1.3
    else:
        chull = grouped.apply(cube)

    chull = chull.reset_index(drop=True)

    ### identify possible stems ###
    if params.verbose: print('identifying stems...')
    skeleton = grouped[xyz + ['n_z', 'n_slice', 'slice']].median().reset_index()
    skeleton.loc[:, 'dbh_node'] = False

    # stays empty when no candidate stems are found so the write loop below
    # still has a table to index (it was previously undefined -> NameError)
    dbh_cylinder = pd.DataFrame(columns=['radius', 'CV', 'cnt', 'x', 'y', 'z'], dtype=float)

    find_stems_min = int(params.find_stems_height // params.slice_thickness)
    find_stems_max = int((params.find_stems_height + params.find_stems_thickness) // params.slice_thickness) + 1
    dbh_nodes_plus = skeleton.loc[skeleton.n_slice.between(find_stems_min, find_stems_max)].clstr
    dbh_slice = stem_pc.loc[stem_pc.clstr.isin(dbh_nodes_plus)].copy()

    # the 10-neighbour noise filter needs at least 11 points
    if len(dbh_slice) > 10:

        # remove noise from dbh slice
        nn = NearestNeighbors(n_neighbors=10).fit(dbh_slice[xyz])
        distances, indices = nn.kneighbors()
        dbh_slice.loc[:, 'nn'] = distances[:, 1:].mean(axis=1)
        dbh_slice = dbh_slice.loc[dbh_slice.nn < dbh_slice.nn.quantile(q=.9)]

        # run dbscan over dbh_slice to find potential stems
        dbscan = DBSCAN(eps=.2, min_samples=50).fit(dbh_slice[['x', 'y']])
        dbh_slice.loc[:, 'clstr_db'] = dbscan.labels_
        dbh_slice = dbh_slice.loc[dbh_slice.clstr_db > -1]
        # label each potential stem by its lowest skeleton cluster id
        dbh_slice.loc[:, 'cclstr'] = dbh_slice.groupby('clstr_db').clstr.transform('min')

        if len(dbh_slice) > 10:

            # ransac cylinder fitting
            if params.verbose: print('fitting cylinders to possible stems...')
            if params.pandarallel:
                cylinder_fits = dbh_slice.groupby('cclstr').parallel_apply(RANSAC_helper, 100, ).to_dict()
            else:
                cylinder_fits = dbh_slice.groupby('cclstr').apply(RANSAC_helper, 100, ).to_dict()
            dbh_cylinder = pd.DataFrame(cylinder_fits).T
            dbh_cylinder.columns = ['radius', 'centre', 'CV', 'cnt']
            dbh_cylinder.loc[:, ['x', 'y', 'z']] = [[*row.centre] for row in dbh_cylinder.itertuples()]
            dbh_cylinder = dbh_cylinder.drop(columns=['centre']).astype(float)

            # clusters whose cylinder is large enough, well supported and has
            # CV <= .15 are labelled as stem (dbh) nodes
            skeleton.loc[skeleton.clstr.isin(dbh_cylinder.loc[(dbh_cylinder.radius > params.find_stems_min_radius) &
                                                              (dbh_cylinder.cnt > params.find_stems_min_points) &
                                                              (dbh_cylinder.CV <= .15)].index.values), 'dbh_node'] = True

    in_tile_stem_nodes = skeleton.loc[(skeleton.dbh_node) &
                                      (skeleton.x.between(bbox.xmin, bbox.xmax)) &
                                      (skeleton.y.between(bbox.ymin, bbox.ymax))].clstr

    # generates paths through all stem points
    if params.verbose: print('generating graph, this may take a while...')
    wood_paths = generate_path(chull,
                               skeleton.loc[skeleton.dbh_node].clstr,
                               n_neighbours=200,
                               max_length=params.graph_edge_length)

    # keep only the shortest path for each clstr
    wood_paths = wood_paths.sort_values(['clstr', 'distance'])
    wood_paths = wood_paths.loc[~wood_paths['clstr'].duplicated()]

    # remove clusters that are linked to a base by a cumulative
    # distance greater than X
    wood_paths = wood_paths.loc[wood_paths.distance <= params.graph_maximum_cumulative_gap]

    if params.verbose: print('merging skeleton points with graph')
    stems = pd.merge(skeleton, wood_paths, on='clstr', how='left')

    # drop any existing colour columns before assigning new ones
    stems.drop(columns=[c for c in stems.columns if c.startswith(('red', 'green', 'blue'))], inplace=True)

    # generate a unique RGB for each stem (helps with visualising)
    unique_stems = stems.t_clstr.unique()
    RGB = pd.DataFrame(data=np.vstack([unique_stems,
                                       np.random.randint(0, 255, size=(3, len(unique_stems)))]).T,
                       columns=['t_clstr', 'red', 'green', 'blue'])
    # colour unassigned points grey; clusters not reached by any path carry
    # NaN (not not_base) after the left merge above
    unassigned = RGB.t_clstr.isna() | (RGB.t_clstr == params.not_base)
    RGB.loc[unassigned, :] = [np.nan, 211, 211, 211]
    stems = pd.merge(stems, RGB, on='t_clstr', how='right')

    # assign all stem points to a tree
    trees = pd.merge(stem_pc,
                     stems[['clstr', 't_clstr', 'distance', 'red', 'green', 'blue']],
                     on='clstr')
    trees.loc[:, 'cnt'] = trees.groupby('t_clstr').t_clstr.transform('count')
    trees = trees.loc[trees.cnt > params.min_points_per_tree]
    in_tile_stem_nodes = trees.loc[trees.t_clstr.isin(in_tile_stem_nodes)].t_clstr.unique()

    # write out all trees, largest radius first
    params.base_I, I = {}, 0
    leafoff_paths = {}  # stem id -> written leaf-off file (reused when adding leaves)
    stem_order = dbh_cylinder.loc[in_tile_stem_nodes].sort_values('radius', ascending=False).index
    for b in tqdm(stem_order,
                  total=len(in_tile_stem_nodes),
                  desc='writing stems to file',
                  disable=not params.verbose):

        if b == params.not_base:
            continue

        out_dir = params.odir
        if params.save_diameter_class:
            # diameter class in 10 cm bins, e.g. "0.3"
            out_dir = os.path.join(params.odir, f'{(dbh_cylinder.loc[b].radius * 2 // .1) / 10:.1f}')
            os.makedirs(out_dir, exist_ok=True)
        out_fn = os.path.join(out_dir, f'{params.n:03}_T{I}.leafoff.ply')
        ply_io.write_ply(out_fn, trees.loc[trees.t_clstr == b])

        leafoff_paths[b] = out_fn
        params.base_I[b] = I
        I += 1

    if params.add_leaves:

        if params.verbose: print('adding leaves to stems, this may take a while...')

        # link stem number to clstr
        stem2tlsctr = stems[['clstr', 't_clstr']].loc[stems.t_clstr != params.not_base].set_index('clstr').to_dict()['t_clstr']
        chull.loc[:, 'stem'] = chull.clstr.map(stem2tlsctr)

        # identify unlabelled woody points to add back to leaves
        unlabelled_wood = chull.loc[chull.stem.isna()]
        unlabelled_wood = stem_pc.loc[stem_pc.clstr.isin(unlabelled_wood.clstr.to_list() + [-1])]

        # extract wood points that are attributed to a base and that are
        # the last clstr of the graph i.e. a tip
        is_tip = wood_paths.set_index('clstr')['is_tip'].to_dict()
        chull = chull.loc[chull.stem.notna()]
        chull.loc[:, 'is_tip'] = chull.clstr.map(is_tip)
        chull = chull.loc[(chull.is_tip) & (chull.n_z > params.find_stems_height)]
        chull.loc[:, 'xlabel'] = 2  # 2 = branch tip, 1 = leaf voxel face

        # process leaf points (label 1) above 2 m normalised height
        lvs = params.pc.loc[(params.pc.label == 1) & (params.pc.n_z >= 2)].copy()
        lvs = pd.concat([lvs, unlabelled_wood], ignore_index=True)
        lvs.reset_index(inplace=True)

        # voxelise
        lvs = voxelise(lvs, length=params.add_leaves_voxel_length)
        lvs_gb = lvs.groupby('VX')[xyz]
        lvs_min = lvs_gb.min()
        lvs_max = lvs_gb.max()
        lvs_med = lvs_gb.median()

        # one point at the centre of each of the six faces of every leaf voxel
        faces = [(lvs_min.x, lvs_med.y, lvs_med.z),
                 (lvs_max.x, lvs_med.y, lvs_med.z),
                 (lvs_med.x, lvs_min.y, lvs_med.z),
                 (lvs_med.x, lvs_max.y, lvs_med.z),
                 (lvs_med.x, lvs_med.y, lvs_min.z),
                 (lvs_med.x, lvs_med.y, lvs_max.z)]
        cnrs = pd.DataFrame(np.vstack([np.vstack(face).T for face in faces]), columns=['x', 'y', 'z'])
        cnrs.loc[:, 'xlabel'] = 1
        # voxel cluster ids start above the largest wood cluster id
        cnrs.loc[:, 'clstr'] = np.tile(np.arange(len(lvs_min.index)) + 1 + chull.clstr.max(), 6)
        cnrs.loc[:, 'VX'] = np.tile(lvs_min.index, 6)

        # and combine leaves and wood
        branch_and_leaves = pd.concat([cnrs, chull[['x', 'y', 'z', 'label', 'stem', 'xlabel', 'clstr']]])
        branch_and_leaves.reset_index(inplace=True, drop=True)

        # find neighbouring branch and leaf points - used as entry points
        nn = NearestNeighbors(n_neighbors=2).fit(branch_and_leaves[xyz])
        distances, indices = nn.kneighbors()
        closest_point_to_leaf = indices[:len(cnrs), :].flatten()  # only leaf points
        idx = np.isin(closest_point_to_leaf, branch_and_leaves.loc[branch_and_leaves.xlabel == 2].index)
        close_branch_points = closest_point_to_leaf[idx]  # points where the branch is closest

        # remove all branch points that are not close to leaves
        idx = np.hstack([branch_and_leaves.iloc[:len(cnrs)].index.values, close_branch_points])
        bal = branch_and_leaves.loc[branch_and_leaves.index.isin(np.unique(idx))]

        # generate a leaf paths graph
        leaf_paths = generate_path(bal,
                                   bal.loc[bal.xlabel == 2].clstr.unique(),
                                   max_length=1,  # i.e. any leaves which are separated by greater are ignored
                                   n_neighbours=20)

        leaf_paths = leaf_paths.sort_values(['clstr', 'distance'])
        leaf_paths = leaf_paths.loc[~leaf_paths['clstr'].duplicated()]  # removes duplicate paths
        leaf_paths = leaf_paths.loc[leaf_paths.distance > 0]  # removes within cluster paths

        # linking indexes to stem number
        top2stem = branch_and_leaves.loc[branch_and_leaves.xlabel == 2].set_index('clstr')['stem'].to_dict()
        leaf_paths.loc[:, 't_clstr'] = leaf_paths.t_clstr.map(top2stem)

        # linking index to VX number
        index2VX = branch_and_leaves.loc[branch_and_leaves.xlabel == 1].set_index('clstr')['VX'].to_dict()
        leaf_paths.loc[:, 'VX'] = leaf_paths['clstr'].map(index2VX)

        # attach each leaf point to the stem its voxel was linked to
        lvs = pd.merge(lvs, leaf_paths[['VX', 't_clstr', 'distance']], on='VX', how='left')

        # and save
        for lv in tqdm(in_tile_stem_nodes,
                       desc='adding leaves',
                       disable=not params.verbose):

            # use the path actually written above; the previous glob over
            # odir/*/ only worked with --save-diameter-class
            wood_fn = leafoff_paths[lv]

            stem = ply_io.read_ply(wood_fn)
            stem.loc[:, 'wood'] = 1

            l2a = lvs.loc[lvs.t_clstr == lv].copy()
            if len(l2a) > 0:
                l2a.loc[:, 'wood'] = 0

                # colour slightly lighter than the stem, capped at 255
                rgb = RGB.loc[RGB.t_clstr == lv][['red', 'green', 'blue']].values[0] * 1.2
                l2a.loc[:, ['red', 'green', 'blue']] = [c if c <= 255 else 255 for c in rgb]

                stem = pd.concat([stem, l2a[['x', 'y', 'z', 'label', 'red', 'green', 'blue', 't_clstr', 'wood', 'distance']]])

            stem = stem.loc[~stem.duplicated()]
            leafon_fn = wood_fn[:-len('leafoff.ply')] + 'leafon.ply'
            ply_io.write_ply(leafon_fn,
                             stem[['x', 'y', 'z', 'red', 'green', 'blue', 'label', 't_clstr', 'wood', 'distance']])
            if params.verbose: print(f"leaf on saved to: {leafon_fn}")