#!/usr/bin/env ambari-python-wrap
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import optparse
import copy
import os
import re
import sys

# Installation prefix of the stack and the version string printed by -v/--version
root = "/usr/odp"
odpVersion = "3.3.6.4-1"

# Directory holding the per-package symlinks into the selected version
current = "%s/current" % root

# Version directory names are split on '-' and '.' before being compared
versionRegex = re.compile(r"[-.]")

# File names for which createCommandSymlinks() does not report an error when
# the path in the bin directory is a regular file instead of a symlink
spark_wrapper_scripts = [
  "run-example",
  "pyspark",
  "spark-class",
  "sparkR",
  "spark-shell",
  "spark-sql",
  "spark-submit",
]

# Leaf packages: each key is a symlink name created under the "current"
# directory, each value is the directory inside a version directory that the
# symlink points to.
leaves = {
  "accumulo-client": "accumulo",
  "accumulo-gc": "accumulo",
  "accumulo-master": "accumulo",
  "accumulo-monitor": "accumulo",
  "accumulo-tablet": "accumulo",
  "accumulo-tracer": "accumulo",
  "atlas-server": "atlas",
  "atlas-client": "atlas",
  "beacon-server": "beacon",
  "beacon-client": "beacon",
  "cruise-control": "cruise-control",
  "cruise-control3": "cruise-control3",
  "falcon-client": "falcon",
  "falcon-server": "falcon",
  "flume-server": "flume",
  "hadoop-client": "hadoop",
  "hadoop-hdfs-client": "hadoop-hdfs",
  "hadoop-mapreduce-client": "hadoop-mapreduce",
  "hadoop-yarn-client": "hadoop-yarn",
  "hadoop-hdfs-datanode": "hadoop-hdfs",
  "hadoop-hdfs-journalnode": "hadoop-hdfs",
  "hadoop-hdfs-nfs3": "hadoop-hdfs",
  "hadoop-hdfs-namenode": "hadoop-hdfs",
  "hadoop-hdfs-portmap": "hadoop-hdfs",
  "hadoop-hdfs-secondarynamenode": "hadoop-hdfs",
  "hadoop-hdfs-zkfc": "hadoop-hdfs",
  "hadoop-httpfs": "hadoop-httpfs",
  "hadoop-mapreduce-historyserver": "hadoop-mapreduce",
  "hadoop-yarn-resourcemanager": "hadoop-yarn",
  "hadoop-yarn-nodemanager": "hadoop-yarn",
  "hadoop-yarn-timelineserver": "hadoop-yarn",
  "hadoop-yarn-timelinereader": "hadoop-yarn",
  "hadoop-yarn-registrydns": "hadoop-yarn",
  "hadoop-hdfs-dfsrouter": "hadoop-hdfs",
  "hbase-client": "hbase",
  "hbase-master": "hbase",
  "hbase-regionserver": "hbase",
  "hive-client": "hive",
  "hive-metastore": "hive",
  "hive-server2": "hive",
  "hive-webhcat": "hive-hcatalog",
  "hive-server2-hive": "hive",
  "hive-server2-hive2": "hive2",
  "kafka-broker": "kafka",
  "kafka3-broker": "kafka3",
  "hive_warehouse_connector": "hive_warehouse_connector",
  "hue": "hue",
  "knox-server": "knox",
  "livy3-client": "livy3",
  "livy3-server": "livy3",
  "livy3_3_3_3-client": "livy3_3_3_3",
  "livy3_3_3_3-server": "livy3_3_3_3",
  "livy3_3_5_1-client": "livy3_3_5_1",
  "livy3_3_5_1-server": "livy3_3_5_1",
  "livy4-client": "livy4",
  "livy4-server": "livy4",
  "mahout-client": "mahout",
  "nifi": "nifi",
  "nifi-toolkit": "nifi-toolkit",
  "nifi-registry": "nifi-registry",
  "nifi2": "nifi2",
  "nifi2-toolkit": "nifi-toolkit2",
  "nifi2-registry": "nifi-registry2",
  "oozie-client": "oozie",
  "oozie-server": "oozie",
  "ozone-client": "ozone",
  "ozone-manager": "ozone",
  "ozone-scm": "ozone",
  "ozone-recon": "ozone",
  "ozone-s3g": "ozone",
  "ozone2-client": "ozone2",
  "ozone2-manager": "ozone2",
  "ozone2-scm": "ozone2",
  "ozone2-recon": "ozone2",
  "ozone2-s3g": "ozone2",
  "phoenix-client": "phoenix",
  "phoenix-server": "phoenix",
  "pig-client": "pig",
  "slider-client": "slider",
  "sqoop-client": "sqoop",
  "sqoop-server": "sqoop",
  "storm-client": "storm",
  "storm-nimbus": "storm",
  "storm-supervisor": "storm",
  "storm-slider-client": "storm-slider-client",
  "tez-client": "tez",
  "zeppelin-server": "zeppelin",
  "zookeeper-client": "zookeeper",
  "zookeeper-server": "zookeeper",
  "ranger-admin": "ranger-admin",
  "ranger-kms": "ranger-kms",
  "ranger-usersync": "ranger-usersync",
  "ranger-tagsync": "ranger-tagsync",
  "shc": "shc",
  "spark3-client": "spark3",
  "spark3-thriftserver": "spark3",
  "spark3-historyserver": "spark3",
  "spark4-client": "spark4",
  "spark4-thriftserver": "spark4",
  "spark4-historyserver": "spark4",
  "spark3-connect": "spark3",
  "spark-rapids-355": "spark-rapids-355",
  "spark-rapids-351": "spark-rapids-351",
  "spark-rapids-333": "spark-rapids-333",
  "spark-rapids-411": "spark-rapids-411",
  "spark3_3_3_3-client": "spark3_3_3_3",
  "spark3_3_3_3-thriftserver": "spark3_3_3_3",
  "spark3_3_3_3-historyserver": "spark3_3_3_3",
  "spark3_3_5_1-client": "spark3_3_5_1",
  "spark3_3_5_1-thriftserver": "spark3_3_5_1",
  "spark3_3_5_1-historyserver": "spark3_3_5_1",
  "druid-overlord": "druid",
  "druid-middlemanager": "druid",
  "druid-coordinator": "druid",
  "druid-broker": "druid",
  "druid-historical": "druid",
  "druid-router": "druid",
  "druid-superset": "superset",
  "superset": "superset",
  "spark-atlas-connector": "spark-atlas-connector",
  "spark-schema-registry": "spark-schema-registry",
  "beacon": "beacon",
  "registry": "registry",
  "airflow": "airflow",
  "pinot": "pinot",
  "trino": "trino",
  "flink-client": "flink",
  "flink-historyserver": "flink",
  "mlflow": "mlflow",
  "kudu": "kudu",
  "clickhouse": "clickhouse",
  "celeborn-master": "celeborn",
  "celeborn-worker": "celeborn",
  "celeborn-client": "celeborn",
  "trino-gateway": "trino-gateway",
}

# Aliases: a group name and the leaf packages it expands to. "all" is bound to
# the live key view of the leaves table.
aliases = {
  "accumulo-server": [
    "accumulo-gc",
    "accumulo-master",
    "accumulo-monitor",
    "accumulo-tablet",
    "accumulo-tracer",
  ],
  "all": leaves.keys(),
  "client": [
    "accumulo-client",
    "atlas-client",
    "beacon-client",
    "celeborn-client",
    "falcon-client",
    "hadoop-client",
    "hbase-client",
    "mahout-client",
    "oozie-client",
    "ozone-client",
    "ozone2-client",
    "phoenix-client",
    "slider-client",
    "sqoop-client",
    "storm-client",
    "zookeeper-client",
    "spark3-client",
    "spark4-client",
    "spark3_3_3_3-client",
    "spark3_3_5_1-client",
    "shc",
  ],
  "hadoop-hdfs-server": [
    "hadoop-hdfs-datanode",
    "hadoop-hdfs-journalnode",
    "hadoop-hdfs-nfs3",
    "hadoop-hdfs-namenode",
    "hadoop-hdfs-secondarynamenode",
  ],
  "hadoop-mapreduce-server": [
    "hadoop-mapreduce-historyserver",
  ],
  "hadoop-yarn-server": [
    "hadoop-yarn-resourcemanager",
    "hadoop-yarn-nodemanager",
    "hadoop-yarn-timelineserver",
    "hadoop-yarn-timelinereader",
    "hadoop-yarn-registrydns",
  ],
  "hive-server": [
    "hive-metastore",
    "hive-server2",
    "hive-webhcat",
  ],
}

# Packages locked to a parent package: whenever the parent is set, setPackages()
# re-points every (child link name, directory inside the version) pair listed
# here, and createCommandSymlinks() treats the children as selected too.
#
# Each child must appear only once: the pairs are applied in order, so a
# repeated child would silently be re-pointed by the later pair. A second
# "livy3-client" pair targeting "livy2" used to override the "livy3" one and
# contradicted the leaves table; it has been dropped.
locked_contexts = {
  "hadoop-client": [("hadoop-hdfs-client", "hadoop-hdfs"),
                    ("hadoop-httpfs", "hadoop-httpfs"),
                    ("hadoop-yarn-client", "hadoop-yarn"),
                    ("hadoop-mapreduce-client", "hadoop-mapreduce"),
                    ("hive-client", "hive"),
                    ("pig-client", "pig"),
                    ("tez-client", "tez"),
                    ("livy-client", "livy"),
                    ("livy3-client", "livy3"),
                    ("livy3_3_3_3-client", "livy3_3_3_3"),
                    ("livy3_3_5_1-client", "livy3_3_5_1"),
                    ("livy4-client", "livy4"),
                    ("spark4-client", "spark4"),
                    ("shc", "shc"),
                    ("spark3-client", "spark3"),
                    ("spark3_3_3_3-client", "spark3_3_3_3"),
                    ("spark3_3_5_1-client", "spark3_3_5_1")]
}

# Command symlinks: each key is a link name created in the bin directory, each
# value is the path below the "current" directory that the link points to.
# The first path component names the package that owns the command.
command_symlinks = {
  "accumulo": "accumulo-client/bin/accumulo",
  "atlas-start": "atlas-server/bin/metadata-start.sh",
  "atlas-stop": "atlas-server/bin/metadata-stop.sh",
  "beacon": "beacon-client/bin/beacon.py",
  "beeline": "hive-client/bin/beeline",
  "falcon": "falcon-client/bin/falcon",
  "flume-ng": "flume-server/bin/flume-ng",
  "hadoop": "hadoop-client/bin/hadoop",
  "hbase": "hbase-client/bin/hbase",
  "phoenix-sqlline": "phoenix-client/bin/sqlline.py",
  "phoenix-sqlline-thin": "phoenix-client/bin/sqlline-thin.py",
  "phoenix-psql": "phoenix-client/bin/psql.py",
  "phoenix-queryserver": "phoenix-server/bin/queryserver.py",
  "hcat": "hive-client/../hive-hcatalog/bin/hcat",
  "hdfs": "hadoop-hdfs-client/bin/hdfs",
  "hive": "hive-client/bin/hive",
  "hiveserver2": "hive-server2/bin/hiveserver2",
  "mahout": "mahout-client/bin/mahout",
  "mapred": "hadoop-mapreduce-client/bin/mapred",
  "oozie": "oozie-client/bin/oozie",
  "oozied.sh": "oozie-server/bin/oozied.sh",
  "ozone": "ozone-client/bin/ozone",
  "ozone2": "ozone2-client/bin/ozone2",
  "pig": "pig-client/bin/pig",
  "slider": "slider-client/bin/slider",
  "sqoop": "sqoop-client/bin/sqoop",
  "sqoop-codegen": "sqoop-client/bin/sqoop-codegen",
  "sqoop-create-hive-table": "sqoop-client/bin/sqoop-create-hive-table",
  "sqoop-eval": "sqoop-client/bin/sqoop-eval",
  "sqoop-export": "sqoop-client/bin/sqoop-export",
  "sqoop-help": "sqoop-client/bin/sqoop-help",
  "sqoop-import": "sqoop-client/bin/sqoop-import",
  "sqoop-import-all-tables": "sqoop-client/bin/sqoop-import-all-tables",
  "sqoop-job": "sqoop-client/bin/sqoop-job",
  "sqoop-list-databases": "sqoop-client/bin/sqoop-list-databases",
  "sqoop-list-tables": "sqoop-client/bin/sqoop-list-tables",
  "sqoop-merge": "sqoop-client/bin/sqoop-merge",
  "sqoop-metastore": "sqoop-server/bin/sqoop-metastore",
  "sqoop-version": "sqoop-client/bin/sqoop-version",
  "storm": "storm-client/bin/storm",
  # NOTE(review): the link name is spelled "worker-lanucher" while its target
  # is "worker-launcher" - confirm whether the link name is intentional.
  "worker-lanucher": "storm-client/bin/worker-launcher",
  "storm-slider": "storm-slider-client/bin/storm-slider",
  "kafka": "kafka-broker/bin/kafka",
  "yarn": "hadoop-yarn-client/bin/yarn",
  "zookeeper-client": "zookeeper-client/bin/zookeeper-client",
  "zookeeper-server": "zookeeper-server/bin/zookeeper-server",
  "zookeeper-server-cleanup": "zookeeper-server/bin/zookeeper-server-cleanup",
  "ranger-admin-start": "ranger-admin/ews/ranger-admin-start",
  "ranger-admin-stop": "ranger-admin/ews/ranger-admin-stop",
  "ranger-kms": "ranger-kms/ranger-kms",
  "ranger-usersync-start": "ranger-usersync/ranger-usersync-start",
  "ranger-usersync-stop": "ranger-usersync/ranger-usersync-stop",
}

# Directory in which the command symlinks are created
bin_directory = "/usr/bin"

# Expand a name into the packages it stands for:
#   None          -> every leaf package
#   an alias name -> the packages listed for that alias
#   a leaf name   -> a one-element list holding that name
# Any other name prints an error plus the valid names and exits with status 1.
def getPackages( name ):
  unknown = name is not None and name not in aliases and name not in leaves
  if unknown:
    print("ERROR: Invalid package - " + name)
    print()
    printPackages()
    sys.exit(1)

  if name is None:
    return leaves.keys()
  # Aliases take precedence over leaf packages of the same name
  return aliases[name] if name in aliases else [ name ]

# Print the status of each of the given packages, one "<package> - <version>"
# line per package in sorted order. None means every leaf package.
#
# The version is the name of the directory that contains the symlink's target.
# "None" is printed when the entry under the "current" directory is missing,
# is a dangling link, or is not a symlink at all.
def listPackages( packages ):
  if packages is None:
    packages = leaves

  for pkg in sorted(packages):
    linkname = current + "/" + pkg
    # os.path.isdir() follows symlinks and is also true for a real directory,
    # on which os.readlink() would raise OSError; only a symlink has a target
    # whose version can be reported.
    if os.path.islink(linkname) and os.path.isdir(linkname):
      version_dir = os.path.dirname(os.readlink(linkname))
      print(pkg + " - " + os.path.basename(version_dir))
    else:
      print(pkg + " - None")

# Print the available package names, then the alias names, each list sorted
# and every name indented by one space under its heading
def printPackages():
  sections = (("Packages:", leaves), ("Aliases:", aliases))
  for heading, table in sections:
    names = sorted(table.keys())
    print(heading)
    for entry in names:
      print(" " + entry)

# Report whether a package is a known leaf package: print its name when it is,
# otherwise print None and exit with status 1
def isSupportedPacakge(package):
  if package not in leaves.keys():
    print(None)
    sys.exit(1)
  else:
    print("%s" % package)

# Print the installed versions, i.e. the entries of the root directory other
# than a few well-known non-version names, ordered by their numeric parts.
# An entry whose name does not split into integers aborts with status 1.
def printVersions():
  skipped = (".", "..", "current", "share", "lost+found")
  by_number = {}
  for entry in os.listdir(root):
    if entry in skipped:
      continue
    try:
      # Compare versions numerically so that e.g. 3.10 sorts after 3.9
      number = tuple(int(part) for part in versionRegex.split(entry))
    except ValueError:
      print ("ERROR: Unexpected file/directory found in %s: %s" % (root, entry))
      sys.exit(1)
    by_number[number] = entry

  for number in sorted(by_number):
    print(by_number[number])

# Point the "current" symlink of every given package at the given version.
#   packages - package names; each must be a key of the leaves table
#   version  - name of a version directory below the root directory
#   rpm_mode - when true, a package whose symlink already exists is left alone
# Exits with status 1 when an argument is missing, the version directory does
# not exist, or a package path exists without being a symlink.
def setPackages(packages, version, rpm_mode):
  if packages is None or version is None:
    print("ERROR: 'set' command must give both package and version")
    print()
    printHelp()
    sys.exit(1)

  version_dir = "/".join((root, version))
  if not os.path.isdir(version_dir):
    print("ERROR: Invalid version " + version)
    print()
    print("Valid choices:")
    printVersions()
    sys.exit(1)

  if not os.path.isdir(current):
    os.mkdir(current, 0o755)

  for pkg in sorted(packages):
    link = "/".join((current, pkg))
    if os.path.islink(link):
      if rpm_mode:
        continue
      os.remove(link)

    # Anything still present at this point is a real file or directory
    if os.path.exists(link):
      print("symlink target %s for %s already exists and it is not a symlink." % (link, leaves[pkg]))
      sys.exit(1)

    os.symlink("/".join((version_dir, leaves[pkg])), link)

    # Packages locked to this one are re-pointed along with it, whatever rpm_mode says
    for (child, subdir) in locked_contexts.get(pkg, ()):
      child_link = "/".join((current, child))
      if os.path.islink(child_link):
        os.remove(child_link)
      os.symlink("/".join((version_dir, subdir)), child_link)

# Create, in the bin directory, the command symlinks owned by the given
# packages and by the packages locked to them.
#   packages - the selected package names
#   rpm_mode - when true, a path that already exists in the bin directory is
#              never touched
# A link pointing elsewhere is replaced with a warning; a regular file in the
# way is reported (unless it is a known Spark wrapper script) and left alone.
def createCommandSymlinks(packages, rpm_mode):
  selected = list(packages)
  for pkg in packages:
    if pkg in locked_contexts:
      selected.extend(child for (child, _) in locked_contexts[pkg])

  for name, destination in command_symlinks.items():
    owner = destination.split('/')[0]
    link = bin_directory + "/" + name
    wanted = current + "/" + destination

    if rpm_mode and os.path.lexists(link):
      continue
    if owner not in selected:
      continue

    if not os.path.lexists(link):
      os.symlink(wanted, link)
      continue

    if os.path.islink(link):
      previous = os.readlink(link)
      if previous != wanted:
        print("WARNING: Replacing link", link, "from", previous)
        os.remove(link)
        os.symlink(wanted, link)
      continue

    if os.path.basename(link) not in spark_wrapper_scripts:
      print("ERROR:", link, "is a regular file instead of symlink.")
      print()
      print("Please ensure that the ODP 2.1 (and earlier) packages are")
      print("removed.")

# Verify that the lookup tables agree with each other, exiting with status 1
# on the first inconsistency:
#   - every member of an alias must be a leaf package
#   - every command symlink must be owned by a leaf package or by a package
#     locked to a parent
def sanityCheckTables():
  for alias, members in aliases.items():
    for member in members:
      if member not in leaves:
        print("ERROR: Alias", alias, "has bad child", member)
        sys.exit(1)

  locked = {kid for children in locked_contexts.values() for (kid, _) in children}

  for symlink, destination in command_symlinks.items():
    owner = destination.split('/')[0]
    if owner in leaves or owner in locked:
      continue
    print("ERROR: command symlink", symlink, "points to an invalid package", owner)
    sys.exit(1)

# Print the usage text. It is maintained by hand, so it must list every option
# registered on the option parser and every command the main section handles.
def printHelp():
  print("""
usage: odp-select [-h] [-v] [-r] [<command>] [<package>] [<version>]

Set the selected version of ODP.

positional arguments:
  <command>   One of set, status, versions, packages, or supports
  <package>   the package name to set
  <version>   the ODP version to set

optional arguments:
  -h, --help      show this help message and exit
  -v, --version   print the ODP version and exit
  -r, --rpm-mode  keep symlinks that already exist and only create the missing ones

Commands:
  set      : set the package to a specified version
  status   : show the version of the package
  versions : show the currently installed versions
  packages : show the individual package names
  supports : check whether odp-select supports a specific package
""")

# Abort with a usage error unless a command received the expected number of
# positional arguments.
#   cmd      - command name shown in the error message
#   realLen  - number of positional arguments given, command name included
#   rightLen - number of positional arguments required, command name included
# The message reports both counts without the command name itself.
def checkCommandParameters(cmd, realLen, rightLen):
  if realLen == rightLen:
    return
  expected = rightLen - 1
  received = realLen - 1
  print("ERROR:", cmd, "command takes", expected, "parameters, instead of", received)
  printHelp()
  sys.exit(1)

 ############################
 #
 # Start of main

# Refuse to run with inconsistent lookup tables
sanityCheckTables()

# optparse's built-in -h handling is disabled so that printHelp() supplies the
# help text instead
parser = optparse.OptionParser(add_help_option=False)
parser.add_option("-h", "--help", action="store_true", dest="help",
                  help="print help")
parser.add_option("-v", "--version", action="store_true", dest="version",
                  help="print odp version")
parser.add_option("-r", "--rpm-mode", action="store_true", dest="rpm_mode", default=False,
                  help="if true checks if there is symlink exists and creates the symlink if it doesn't")
(options, args) = parser.parse_args()

if options.help:
  printHelp()
elif options.version:
  print(odpVersion)
else:
  # No command at all behaves like "status" for every package
  if len(args) == 0 or args[0] == 'status':
    if len(args) > 1:
      listPackages(getPackages(args[1]))
    else:
      listPackages(getPackages("all"))
  elif args[0] == "set":
    checkCommandParameters('set', len(args), 3)
    pkgs = getPackages(args[1])
    setPackages(pkgs, args[2], options.rpm_mode)
    createCommandSymlinks(pkgs, options.rpm_mode)
  elif args[0] == "versions":
    checkCommandParameters('versions', len(args), 1)
    printVersions()
  elif args[0] == "packages":
    checkCommandParameters('packages', len(args), 1)
    printPackages()
  elif args[0] == "supports":
    # Report argument errors under the command's real name, "supports"
    checkCommandParameters('supports', len(args), 2)
    isSupportedPacakge(args[1])
  else:
    print("ERROR: Unknown command -", args[0])
    printHelp()
    sys.exit(1)
