From 4a44d504dd1da5412a1bb3f0ea58dd8fbf8b7c85 Mon Sep 17 00:00:00 2001 From: Jason Lam Date: Tue, 21 Mar 2017 16:43:47 -0400 Subject: [PATCH] TensorFlow on YARN --- yarn/README.md | 151 ++ yarn/bin/README.md | 145 ++ yarn/bin/ytf-submit | 284 +++ yarn/example/README | 38 + yarn/example/mnist.py | 196 +++ yarn/pom.xml | 187 ++ .../hadoop/yarn/ApplicationMaster.java | 1420 +++++++++++++++ .../org/tensorflow/hadoop/yarn/Client.java | 738 ++++++++ .../hadoop/yarn/Log4jPropertyHelper.java | 60 + .../tensorflow/hadoop/yarn/TFConstants.java | 62 + .../hadoop/yarn/TFContainerRequest.java | 58 + .../org/tensorflow/hadoop/yarn/TFServlet.java | 49 + .../org/tensorflow/hadoop/yarn/TFSession.java | 708 ++++++++ .../python/task-starter/kazoo/__init__.py | 1 + .../main/python/task-starter/kazoo/client.py | 1561 +++++++++++++++++ .../python/task-starter/kazoo/exceptions.py | 209 +++ .../task-starter/kazoo/handlers/__init__.py | 1 + .../task-starter/kazoo/handlers/eventlet.py | 173 ++ .../task-starter/kazoo/handlers/gevent.py | 163 ++ .../task-starter/kazoo/handlers/threading.py | 196 +++ .../task-starter/kazoo/handlers/utils.py | 229 +++ .../main/python/task-starter/kazoo/hosts.py | 25 + .../python/task-starter/kazoo/interfaces.py | 203 +++ .../task-starter/kazoo/loggingsupport.py | 1 + .../task-starter/kazoo/protocol/__init__.py | 1 + .../task-starter/kazoo/protocol/connection.py | 630 +++++++ .../task-starter/kazoo/protocol/paths.py | 55 + .../kazoo/protocol/serialization.py | 415 +++++ .../task-starter/kazoo/protocol/states.py | 237 +++ .../task-starter/kazoo/python2atexit.py | 69 + .../task-starter/kazoo/recipe/__init__.py | 1 + .../task-starter/kazoo/recipe/barrier.py | 215 +++ .../task-starter/kazoo/recipe/counter.py | 94 + .../task-starter/kazoo/recipe/election.py | 79 + .../python/task-starter/kazoo/recipe/lease.py | 130 ++ .../python/task-starter/kazoo/recipe/lock.py | 584 ++++++ .../task-starter/kazoo/recipe/partitioner.py | 423 +++++ .../python/task-starter/kazoo/recipe/party.py | 118 ++ .../python/task-starter/kazoo/recipe/queue.py | 330 ++++ .../task-starter/kazoo/recipe/watchers.py | 419 +++++ .../main/python/task-starter/kazoo/retry.py | 153 ++ .../python/task-starter/kazoo/security.py | 138 ++ .../task-starter/kazoo/testing/__init__.py | 5 + .../task-starter/kazoo/testing/common.py | 308 ++++ .../task-starter/kazoo/testing/harness.py | 165 ++ .../python/task-starter/wrapper/__init__.py | 0 .../python/task-starter/wrapper/__main__.py | 231 +++ .../hadoop/yarn/EmbeddedZKServer.java | 67 + .../org/tensorflow/hadoop/yarn/TestTF.java | 408 +++++ yarn/src/test/resources/env_check.sh | 11 + yarn/src/test/resources/log4j.properties | 19 + yarn/src/test/resources/test.sh | 38 + yarn/src/test/resources/test_task.sh | 1 + yarn/src/test/resources/yarn-site.xml | 21 + yarn/task-starter-assembly.xml | 31 + 55 files changed, 12254 insertions(+) create mode 100644 yarn/README.md create mode 100644 yarn/bin/README.md create mode 100644 yarn/bin/ytf-submit create mode 100644 yarn/example/README create mode 100644 yarn/example/mnist.py create mode 100644 yarn/pom.xml create mode 100644 yarn/src/main/java/org/tensorflow/hadoop/yarn/ApplicationMaster.java create mode 100644 yarn/src/main/java/org/tensorflow/hadoop/yarn/Client.java create mode 100644 yarn/src/main/java/org/tensorflow/hadoop/yarn/Log4jPropertyHelper.java create mode 100644 yarn/src/main/java/org/tensorflow/hadoop/yarn/TFConstants.java create mode 100644 yarn/src/main/java/org/tensorflow/hadoop/yarn/TFContainerRequest.java create mode 100644 
yarn/src/main/java/org/tensorflow/hadoop/yarn/TFServlet.java create mode 100644 yarn/src/main/java/org/tensorflow/hadoop/yarn/TFSession.java create mode 100644 yarn/src/main/python/task-starter/kazoo/__init__.py create mode 100644 yarn/src/main/python/task-starter/kazoo/client.py create mode 100644 yarn/src/main/python/task-starter/kazoo/exceptions.py create mode 100644 yarn/src/main/python/task-starter/kazoo/handlers/__init__.py create mode 100644 yarn/src/main/python/task-starter/kazoo/handlers/eventlet.py create mode 100644 yarn/src/main/python/task-starter/kazoo/handlers/gevent.py create mode 100644 yarn/src/main/python/task-starter/kazoo/handlers/threading.py create mode 100644 yarn/src/main/python/task-starter/kazoo/handlers/utils.py create mode 100644 yarn/src/main/python/task-starter/kazoo/hosts.py create mode 100644 yarn/src/main/python/task-starter/kazoo/interfaces.py create mode 100644 yarn/src/main/python/task-starter/kazoo/loggingsupport.py create mode 100644 yarn/src/main/python/task-starter/kazoo/protocol/__init__.py create mode 100644 yarn/src/main/python/task-starter/kazoo/protocol/connection.py create mode 100644 yarn/src/main/python/task-starter/kazoo/protocol/paths.py create mode 100644 yarn/src/main/python/task-starter/kazoo/protocol/serialization.py create mode 100644 yarn/src/main/python/task-starter/kazoo/protocol/states.py create mode 100644 yarn/src/main/python/task-starter/kazoo/python2atexit.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/__init__.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/barrier.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/counter.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/election.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/lease.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/lock.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/partitioner.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/party.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/queue.py create mode 100644 yarn/src/main/python/task-starter/kazoo/recipe/watchers.py create mode 100644 yarn/src/main/python/task-starter/kazoo/retry.py create mode 100644 yarn/src/main/python/task-starter/kazoo/security.py create mode 100644 yarn/src/main/python/task-starter/kazoo/testing/__init__.py create mode 100644 yarn/src/main/python/task-starter/kazoo/testing/common.py create mode 100644 yarn/src/main/python/task-starter/kazoo/testing/harness.py create mode 100644 yarn/src/main/python/task-starter/wrapper/__init__.py create mode 100644 yarn/src/main/python/task-starter/wrapper/__main__.py create mode 100644 yarn/src/test/java/org/tensorflow/hadoop/yarn/EmbeddedZKServer.java create mode 100644 yarn/src/test/java/org/tensorflow/hadoop/yarn/TestTF.java create mode 100644 yarn/src/test/resources/env_check.sh create mode 100644 yarn/src/test/resources/log4j.properties create mode 100644 yarn/src/test/resources/test.sh create mode 100644 yarn/src/test/resources/test_task.sh create mode 100644 yarn/src/test/resources/yarn-site.xml create mode 100644 yarn/task-starter-assembly.xml diff --git a/yarn/README.md b/yarn/README.md new file mode 100644 index 00000000..c5873be3 --- /dev/null +++ b/yarn/README.md @@ -0,0 +1,151 @@ +# TensorFlow launcher for Apache Hadoop YARN + +This project implements a [TensorFlow](http://www.tensorflow.org/) session +launcher for [Apache Hadoop 
YARN](http://hadoop.apache.org/docs/current/hadoop-yarn/hadoop-yarn-site/YARN.html),
+such that users can utilize resources in a YARN cluster. It supports both
+local and distributed TensorFlow applications.
+
+## Prerequisites
+
+1. Apache Hadoop YARN
+2. Zookeeper
+3. Python 2.6+
+4. TensorFlow + related packages
+5. Docker [optional]
+
+In particular, TensorFlow and its necessary packages must be either
+pre-installed on nodes in the YARN cluster or be available as a Docker image
+accessible from those nodes.
+
+## Build
+
+```sh
+mvn clean package
+```
+
+## Configure
+
+Configure the Apache Hadoop YARN cluster with the Registry/Zookeeper enabled.
+
+## Examples
+Tasks are submitted using the `ytf-submit` script.
+
+```
+ytf-submit [OPTIONS] -r <cluster_requirement> <task_command>
+```
+
+`task_command` is the command to be executed for each task of the session.
+The two environment variables, `DTF_TASK_JOB_NAME` and `DTF_TASK_INDEX`, will be
+set before the task is executed. `cluster_requirement` is a comma-separated list
+of job names and the number of instances for each job, with this format:
+`<job_name>:<num_tasks>,<job_name>:<num_tasks>, ...`.
+
+### Simple task submission
+
+Let's execute a session with 1 x *Parameter Server (ps)* and 4 x *Workers*.
+Assume the task program, input data, and training output all reside in
+`/home/user1/mnist`, which is accessible to every node.
+
+```sh
+$ ytf-submit -r "ps:1,worker:4" \
+'python /home/user1/mnist/mnist.py \
+--job_name ${DTF_TASK_JOB_NAME} --task_index ${DTF_TASK_INDEX} \
+--ps_hosts ${DTF_PS_HOSTS} --worker_hosts ${DTF_WORKER_HOSTS} \
+--data_dir /home/user1/mnist/data --train_dir /home/user1/mnist/train'
+```
+
+### Enabling TensorBoard
+TensorBoard is enabled by `--tensorboard` or `-t`. The address of TensorBoard is
+available in the **Tracking URL** section of the submitted application in the
+Apache YARN Resource Manager web interface. To use TensorBoard, the output path
+must be specified by `--output` or `-o`. The `DTF_OUTPUT_PATH` environment
+variable will be set and can be used in `task_command`. Similarly, an input
+path can be passed to `ytf-submit` and will be available as `DTF_INPUT_PATH`.
+
+```sh
+$ ytf-submit --tensorboard \
+-i /home/user1/mnist/data -o /home/user1/mnist/train10 -r "ps:1,worker:2" \
+'python /home/user1/mnist/mnist.py \
+--job_name ${DTF_TASK_JOB_NAME} --task_index ${DTF_TASK_INDEX} \
+--ps_hosts ${DTF_PS_HOSTS} --worker_hosts ${DTF_WORKER_HOSTS} \
+--data_dir ${DTF_INPUT_PATH} --train_dir ${DTF_OUTPUT_PATH}'
+```
+
+### Passing the script file
+
+The training code itself can be passed to `ytf-submit`. The code will be copied
+to HDFS and will be available at execution time. The path to the training code
+will be available in the `DTF_TASK_SCRIPT` environment variable.
+
+```sh
+$ ytf-submit --tensorboard \
+-i /home/user1/mnist/data -o /home/user1/mnist/train10 -r "ps:1,worker:2" \
+-s /home/user1/mnist/mnist.py \
+'python ${DTF_TASK_SCRIPT} \
+--job_name ${DTF_TASK_JOB_NAME} --task_index ${DTF_TASK_INDEX} \
+--ps_hosts ${DTF_PS_HOSTS} --worker_hosts ${DTF_WORKER_HOSTS} \
+--data_dir ${DTF_INPUT_PATH} --train_dir ${DTF_OUTPUT_PATH}'
+```
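+
+Because the `DTF_*` variables are exported into every task's environment, the
+training script may also read them directly instead of receiving them as
+flags. A minimal sketch (the `cluster_from_env` helper is illustrative, not
+part of this project):
+
+```python
+import os
+import tensorflow as tf
+
+def cluster_from_env():
+    # DTF_PS_HOSTS / DTF_WORKER_HOSTS hold "host1:port1,host2:port2,..."
+    ps_hosts = os.environ["DTF_PS_HOSTS"].split(",")
+    worker_hosts = os.environ["DTF_WORKER_HOSTS"].split(",")
+    return tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
+
+server = tf.train.Server(cluster_from_env(),
+                         job_name=os.environ["DTF_TASK_JOB_NAME"],
+                         task_index=int(os.environ["DTF_TASK_INDEX"]))
+```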
+
+### Using HDFS paths
+Input and output paths can be HDFS paths.
+
+```sh
+$ ytf-submit --tensorboard \
+-i hdfs://users/user1/mnist/data -o hdfs://users/user1/mnist/train10 \
+-r "ps:1,worker:2" -s /home/user1/mnist/mnist.py \
+'python ${DTF_TASK_SCRIPT} \
+--job_name ${DTF_TASK_JOB_NAME} --task_index ${DTF_TASK_INDEX} \
+--ps_hosts ${DTF_PS_HOSTS} --worker_hosts ${DTF_WORKER_HOSTS} \
+--data_dir ${DTF_INPUT_PATH} --train_dir ${DTF_OUTPUT_PATH}'
+```
+
+### Using Docker
+To execute the tasks as Docker containers, pass the Docker image name using
+`--docker_image <image_name>`. The Docker image must be accessible on the
+execution host. In addition to the variables in **TASK EXECUTION ENVIRONMENT**,
+the following paths are mounted in the container.
+
+- `HADOOP_HOME`, `HADOOP_CONF_DIR`, `JAVA_HOME`
+- `DTF_INPUT_PATH` and `DTF_OUTPUT_PATH`, if they are not HDFS paths.
+
+## TASK EXECUTION ENVIRONMENT
+
+ The user-specified `task_command` will be executed in a YARN container
+ allocated to the session. The following environment variables will be
+ set for the `task_command` to consume.
+- `DTF_TASK_SCRIPT`:
+
+  Name of the file which contains the content of the `script_file` specified
+  during submission.
+
+- `DTF_INPUT_PATH`:
+
+  Input path specified during submission.
+
+- `DTF_OUTPUT_PATH`:
+
+  Output path specified during submission.
+
+- `DTF_{JOBNAME}_HOSTS`:
+
+  Variable with the list of hosts (and ports) allocated to the job with name
+  `{JOBNAME}`.
+
+  - Format: "host1:port1,host2:port2,..."
+
+  The number of host:port entries in the list matches the number specified in
+  `cluster_requirement`. For example, `DTF_PS_HOSTS` and `DTF_WORKER_HOSTS`
+  would commonly be used for PS and WORKER jobs.
+
+- `DTF_TASK_JOB_NAME`:
+
+  Name of the job this task is assigned to. See also `DTF_TASK_INDEX`.
+
+- `DTF_TASK_INDEX`:
+
+  Index of this task within the job it is assigned to. The tuple of
+  `DTF_TASK_JOB_NAME` and `DTF_TASK_INDEX` can also be used to cross-reference
+  `DTF_{JOBNAME}_HOSTS`, for example to get the dynamic port allocated to
+  this task.
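+
+For example, a task can recover its own dynamically allocated endpoint by
+indexing `DTF_{JOBNAME}_HOSTS` with its own job name and index; a sketch:
+
+```python
+import os
+
+job = os.environ["DTF_TASK_JOB_NAME"]        # e.g. "worker"
+index = int(os.environ["DTF_TASK_INDEX"])    # e.g. 2
+hosts = os.environ["DTF_%s_HOSTS" % job.upper()].split(",")
+host, port = hosts[index].rsplit(":", 1)     # this task's own host and port
+```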
diff --git a/yarn/bin/README.md b/yarn/bin/README.md
new file mode 100644
index 00000000..8c56d812
--- /dev/null
+++ b/yarn/bin/README.md
@@ -0,0 +1,145 @@
+Submit Command Line
+-------------------
+
+```sh
+% ytf-submit -h
+NAME
+  ytf-submit - Submit a TensorFlow session to Apache Hadoop YARN
+
+  This tool submits a YARN application master, responsible for allocating
+  the required resources and executing the corresponding tasks.
+
+SYNOPSIS
+  Usage: ./ytf-submit [OPTIONS] -r <cluster_requirement> <task_command>
+
+DESCRIPTION
+  task_command
+    The command to be executed for each task of the session. The two
+    environment variables DTF_TASK_JOB_NAME and DTF_TASK_INDEX will be set
+    before the task is executed.
+    See also TASK EXECUTION ENVIRONMENT
+
+  -r, --cluster_requirement <cluster_requirement>
+    Specify the cluster requirement for the session.
+    Format: <job_name>:<num_tasks>,<job_name>:<num_tasks>,...
+    Example: "ps:2,worker:4"
+    See also TASK EXECUTION ENVIRONMENT
+
+  Additional options:
+
+  -c, --task_vcores <num_vcores>
+    General form to specify the number of vcores required by each task.
+    DEFAULT=1
+
+  -c, --task_vcores <job_name>:<num_vcores>
+    **NOT IMPLEMENTED YET**
+    Job-level form to specify the number of vcores required by tasks in a
+    specific job. Overrides the "general" form.
+
+  -c, --task_vcores <job_name>[<task_index>]:<num_vcores>
+    **NOT IMPLEMENTED YET**
+    Task-level form to specify the number of vcores required by a specific
+    task. Overrides both the "job-level" and "general" forms.
+
+  -m, --task_memory <memory_in_MB>
+    General form to specify the amount of memory required by each task,
+    in MB. DEFAULT=8192
+
+  -m, --task_memory <job_name>:<memory_in_MB>
+    **NOT IMPLEMENTED YET**
+    Job-level form to specify the amount of memory required by tasks in a
+    specific job. Overrides the "general" form.
+
+  -m, --task_memory <job_name>[<task_index>]:<memory_in_MB>
+    **NOT IMPLEMENTED YET**
+    Task-level form to specify the amount of memory required by a specific
+    task. Overrides both the "job-level" and "general" forms.
+
+  -i, --input <input_path>
+    Input path. This value is not interpreted by YARN-DTF at the moment;
+    it serves as a convenience. Its value will be set as the environment
+    variable {DTF_INPUT_PATH} in the task execution environment.
+    DEFAULT=
+
+  -o, --output <output_path>
+    Output path. This value is not interpreted by YARN-DTF at the moment;
+    it serves as a convenience. Its value will be set as the environment
+    variable {DTF_OUTPUT_PATH} in the task execution environment.
+
+    However, when TensorBoard integration is enabled, this option becomes
+    mandatory. See also the --tensorboard option.
+
+  -s, --script <script_file>
+    A local script file to be transferred to the task execution environment,
+    where a file named by the variable {DTF_TASK_SCRIPT} will contain the
+    content of the script file. For example, if the script is a Python
+    script, the execution command can be written as
+    "python ${DTF_TASK_SCRIPT} ..."
+
+  -t, --tensorboard
+    Enable TensorBoard integration. When enabled, YARN-DTF will start an
+    additional YARN container running TensorBoard on the output path
+    specified by the --output option. DEFAULT=disabled
+
+  --docker_image <image_name>
+    Enable tasks to be executed as Docker containers. The Docker image is
+    required to be accessible on the execution host. In addition to the
+    variables in TASK EXECUTION ENVIRONMENT, the following paths are
+    mounted into the container from the execution host.
+
+    HADOOP_HOME, HADOOP_CONF_DIR, JAVA_HOME.
+    DTF_INPUT_PATH and DTF_OUTPUT_PATH, if they are not HDFS paths.
+
+  -q, --queue <queue_name>
+    Specify which YARN queue to submit this session to.
+    DEFAULT=default
+
+  -n, --name <session_name>
+    Name of this session; it will be used as the name of the YARN
+    application. DEFAULT=TensorFlow
+
+  --client
+    **NOT IMPLEMENTED YET**
+    Specify that an additional task should be started locally. This
+    would be useful if user interaction is required.
+
+    This task will share the same execution environment as the rest of the
+    tasks, and will be assigned DTF_TASK_JOB_NAME=client and
+    DTF_TASK_INDEX=0; however, it will not be part of the TensorFlow
+    cluster, and dynamic port allocation will not apply.
+
+TASK EXECUTION ENVIRONMENT
+
+  The user-specified 'task_command' will be executed in a YARN container
+  allocated to the session. The following environment variables will be
+  set for the 'task_command' to consume.
+
+  DTF_TASK_SCRIPT:
+    Name of the file which contains the content of the 'script_file'
+    specified during submission.
+
+  DTF_INPUT_PATH:
+    Input path specified during submission.
+
+  DTF_OUTPUT_PATH:
+    Output path specified during submission.
+
+  DTF_{JOBNAME}_HOSTS:
+    Variable with the list of hosts (and ports) allocated to the job with
+    name {JOBNAME}.
+    Format: "host1:port1,host2:port2,..."
+    The number of host:port entries in the list matches the number
+    specified in "cluster_requirement". For example, DTF_PS_HOSTS and
+    DTF_WORKER_HOSTS would commonly be used for PS and WORKER jobs.
+
+  DTF_TASK_JOB_NAME:
+    Name of the job this task is assigned to. See also DTF_TASK_INDEX.
+
+  DTF_TASK_INDEX:
+    Index of this task within the job it is assigned to.
+    The tuple of DTF_TASK_JOB_NAME and DTF_TASK_INDEX can also be used to
+    cross-reference DTF_{JOBNAME}_HOSTS, for example to get the dynamic
+    port allocated to this task.
+```
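+
+Note on quoting: the DTF_* variables exist only inside the task execution
+environment, so the `task_command` should be single-quoted at submission time;
+otherwise the submitting shell expands `${DTF_...}` to empty strings before
+the command ever reaches the containers. A sketch (`train.py` is a
+placeholder):
+
+```sh
+# Expanded on the container at task launch (correct):
+ytf-submit -r "ps:1,worker:2" 'python train.py --task_index ${DTF_TASK_INDEX}'
+
+# Expanded by the local shell at submit time (almost certainly wrong):
+ytf-submit -r "ps:1,worker:2" "python train.py --task_index ${DTF_TASK_INDEX}"
+```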
diff --git a/yarn/bin/ytf-submit b/yarn/bin/ytf-submit
new file mode 100644
index 00000000..7d90ebdf
--- /dev/null
+++ b/yarn/bin/ytf-submit
@@ -0,0 +1,284 @@
+#!/bin/bash
+
+MYDIR=$(dirname "$0")
+
+JAR=hadoop-yarn-tensorflow-2.7.2.jar
+JAR_PATH=${MYDIR}/../target
+CLIENT_JAR=${JAR_PATH}/${JAR}
+AM_JAR=${JAR_PATH}/${JAR}
+
+YARN=yarn
+YARN_OPTS="jar ${CLIENT_JAR} -jar ${AM_JAR}"
+
+DEF_TASK_VCORES=1
+DEF_TASK_MEMORY=8192
+DEF_INPUT_PATH=""
+DEF_YARN_QUEUE="default"
+DEF_NAME="TensorFlow"
+
+printUsage() {
+  cat <<EOH
+Usage: $0 [OPTIONS] -r <cluster_requirement> <task_command>
+
+  task_command
+    The command to be executed for each task of the session. The two
+    environment variables DTF_TASK_JOB_NAME and DTF_TASK_INDEX will be set
+    before the task is executed.
+    See also TASK EXECUTION ENVIRONMENT
+
+  -r, --cluster_requirement <cluster_requirement>
+    Specify the cluster requirement for the session.
+    Format: <job_name>:<num_tasks>,<job_name>:<num_tasks>,...
+    Example: "ps:2,worker:4"
+    See also TASK EXECUTION ENVIRONMENT
+
+Use "$0 --help" to see additional information
+EOH
+}
+
+printManual() {
+  cat <<EOH
+NAME
+  ytf-submit - Submit a TensorFlow session to Apache Hadoop YARN
+
+  This tool submits a YARN application master, responsible for allocating
+  the required resources and executing the corresponding tasks.
+
+SYNOPSIS
+  Usage: ./ytf-submit [OPTIONS] -r <cluster_requirement> <task_command>
+
+DESCRIPTION
+  task_command
+    The command to be executed for each task of the session. The two
+    environment variables DTF_TASK_JOB_NAME and DTF_TASK_INDEX will be set
+    before the task is executed.
+    See also TASK EXECUTION ENVIRONMENT
+
+  -r, --cluster_requirement <cluster_requirement>
+    Specify the cluster requirement for the session.
+    Format: <job_name>:<num_tasks>,<job_name>:<num_tasks>,...
+    Example: "ps:2,worker:4"
+    See also TASK EXECUTION ENVIRONMENT
+
+  Additional options:
+
+  -c, --task_vcores <num_vcores>
+    General form to specify the number of vcores required by each task.
+    DEFAULT=${DEF_TASK_VCORES}
+
+  -c, --task_vcores <job_name>:<num_vcores>
+    **NOT IMPLEMENTED YET**
+    Job-level form to specify the number of vcores required by tasks in a
+    specific job. Overrides the "general" form.
+
+  -c, --task_vcores <job_name>[<task_index>]:<num_vcores>
+    **NOT IMPLEMENTED YET**
+    Task-level form to specify the number of vcores required by a specific
+    task. Overrides both the "job-level" and "general" forms.
+
+  -m, --task_memory <memory_in_MB>
+    General form to specify the amount of memory required by each task,
+    in MB. DEFAULT=${DEF_TASK_MEMORY}
+
+  -m, --task_memory <job_name>:<memory_in_MB>
+    **NOT IMPLEMENTED YET**
+    Job-level form to specify the amount of memory required by tasks in a
+    specific job. Overrides the "general" form.
+
+  -m, --task_memory <job_name>[<task_index>]:<memory_in_MB>
+    **NOT IMPLEMENTED YET**
+    Task-level form to specify the amount of memory required by a specific
+    task. Overrides both the "job-level" and "general" forms.
+
+  -i, --input <input_path>
+    Input path. This value is not interpreted by YARN-DTF at the moment;
+    it serves as a convenience. Its value will be set as the environment
+    variable {DTF_INPUT_PATH} in the task execution environment.
+    DEFAULT=${DEF_INPUT_PATH}
+
+  -o, --output <output_path>
+    Output path. This value is not interpreted by YARN-DTF at the moment;
+    it serves as a convenience. Its value will be set as the environment
+    variable {DTF_OUTPUT_PATH} in the task execution environment.
+
+    However, when TensorBoard integration is enabled, this option becomes
+    mandatory. See also the --tensorboard option.
+
+  -s, --script <script_file>
+    A local script file to be transferred to the task execution environment,
+    where a file named by the variable {DTF_TASK_SCRIPT} will contain the
+    content of the script file. For example, if the script is a Python
+    script, the execution command can be written as
+    "python \${DTF_TASK_SCRIPT} ..."
+
+  -t, --tensorboard
+    Enable TensorBoard integration. When enabled, YARN-DTF will start an
+    additional YARN container running TensorBoard on the output path
+    specified by the --output option. DEFAULT=disabled
+
+  --docker_image <image_name>
+    Enable tasks to be executed as Docker containers. The Docker image is
+    required to be accessible on the execution host. In addition to the
+    variables in TASK EXECUTION ENVIRONMENT, the following paths are
+    mounted into the container from the execution host.
+
+    HADOOP_HOME, HADOOP_CONF_DIR, JAVA_HOME.
+    DTF_INPUT_PATH and DTF_OUTPUT_PATH, if they are not HDFS paths.
+
+  -q, --queue <queue_name>
+    Specify which YARN queue to submit this session to.
+    DEFAULT=${DEF_YARN_QUEUE}
+
+  -n, --name <session_name>
+    Name of this session; it will be used as the name of the YARN
+    application. DEFAULT=${DEF_NAME}
+
+  --client
+    **NOT IMPLEMENTED YET**
+    Specify that an additional task should be started locally. This
+    would be useful if user interaction is required.
+
+    This task will share the same execution environment as the rest of the
+    tasks, and will be assigned DTF_TASK_JOB_NAME=client and
+    DTF_TASK_INDEX=0; however, it will not be part of the TensorFlow
+    cluster, and dynamic port allocation will not apply.
+
+TASK EXECUTION ENVIRONMENT
+
+  The user-specified 'task_command' will be executed in a YARN container
+  allocated to the session. The following environment variables will be
+  set for the 'task_command' to consume.
+
+  DTF_TASK_SCRIPT:
+    Name of the file which contains the content of the 'script_file'
+    specified during submission.
+
+  DTF_INPUT_PATH:
+    Input path specified during submission.
+
+  DTF_OUTPUT_PATH:
+    Output path specified during submission.
+
+  DTF_{JOBNAME}_HOSTS:
+    Variable with the list of hosts (and ports) allocated to the job with
+    name {JOBNAME}.
+    Format: "host1:port1,host2:port2,..."
+    The number of host:port entries in the list matches the number
+    specified in "cluster_requirement". For example, DTF_PS_HOSTS and
+    DTF_WORKER_HOSTS would commonly be used for PS and WORKER jobs.
+
+  DTF_TASK_JOB_NAME:
+    Name of the job this task is assigned to. See also DTF_TASK_INDEX.
+
+  DTF_TASK_INDEX:
+    Index of this task within the job it is assigned to.
+    The tuple of DTF_TASK_JOB_NAME and DTF_TASK_INDEX can also be used to
+    cross-reference DTF_{JOBNAME}_HOSTS, for example to get the dynamic
+    port allocated to this task.
+
+EOH
+  return 0
+}
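+
+# The defaults below can be overridden by the options parsed in the loop that
+# follows; any argument that does not match a known option accumulates into
+# ARGS and is passed through to the application master as the task_command.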
+#default values
+YARN_QUEUE=${DEF_YARN_QUEUE}
+NAME=${DEF_NAME}
+TENSORBOARD="" #disabled
+DOCKER_IMAGE=""
+TASK_VCORES=${DEF_TASK_VCORES}
+TASK_MEMORY=${DEF_TASK_MEMORY}
+INPUT_PATH=${DEF_INPUT_PATH}
+TASK_SCRIPT=""
+ARGS=""
+
+while [[ $# -ge 1 ]]
+do
+  key="$1"
+  case $key in
+    -t|--tensorboard)
+      TENSORBOARD=-enable_tensorboard
+      ;;
+    -c|--task_vcores)
+      TASK_VCORES="$2"
+      shift # past argument
+      ;;
+    -m|--task_memory)
+      TASK_MEMORY="$2"
+      shift # past argument
+      ;;
+    -r|--cluster_requirement)
+      JOBS_TASKS="$2"
+      shift # past argument
+      ;;
+    -i|--input)
+      INPUT_PATH="$2"
+      shift # past argument
+      ;;
+    -o|--output)
+      OUTPUT_PATH="$2"
+      shift # past argument
+      ;;
+    -s|--script)
+      TASK_SCRIPT="$2"
+      shift # past argument
+      ;;
+    --docker_image)
+      DOCKER_IMAGE="$2"
+      shift # past argument
+      ;;
+    -q|--queue)
+      YARN_QUEUE="$2"
+      shift # past argument
+      ;;
+    -n|--name)
+      NAME="$2"
+      shift # past argument
+      ;;
+    -h|--help|--manual)
+      printManual
+      exit
+      ;;
+    *)
+      ARGS="${ARGS} ${key}"
+      ;;
+  esac
+  shift # past argument or value
+done
+
+TASK_CMD="${ARGS}"
+
+if [ "${TASK_CMD}" == "" ]; then
+  echo "Missing mandatory 'task_command'"
+  printUsage
+  exit 1
+fi
+
+if [ "${JOBS_TASKS}" == "" ]; then
+  echo "Missing mandatory option -r,--cluster_requirement"
+  printUsage
+  exit 1
+fi
+
+if [ "${TENSORBOARD}" != "" ] && [ "${OUTPUT_PATH}" == "" ]; then
+  echo "Missing mandatory option -o,--output when using -t,--tensorboard"
+  printUsage
+  exit 1
+fi
+
+DOCKER_OPT=
+
+exec ${YARN} ${YARN_OPTS} \
+  -queue "${YARN_QUEUE}" \
+  -appname "${NAME}" \
+  ${TENSORBOARD} \
+  --docker_image "${DOCKER_IMAGE}" \
+  -container_vcores "${TASK_VCORES}" \
+  -container_memory "${TASK_MEMORY}" \
+  -num_containers "${JOBS_TASKS}" \
+  -input_path "${INPUT_PATH}" \
+  -output_path "${OUTPUT_PATH}" \
+  -task_script "${TASK_SCRIPT}" \
+  -task_cmd "${TASK_CMD}"
+
diff --git a/yarn/example/README b/yarn/example/README
new file mode 100644
index 00000000..b7fc504a
--- /dev/null
+++ b/yarn/example/README
@@ -0,0 +1,38 @@
+Create TFRecords file:
+
+- Download "convert_to_records.py", a program used to download and convert the
+  MNIST dataset to the TFRecords (TensorFlow) format
+
+$ curl -k https://raw.githubusercontent.com/tensorflow/tensorflow/r0.11/tensorflow/examples/how_tos/reading_data/convert_to_records.py > convert_to_records.py
+
+- Download the dataset and convert it
+
+$ python convert_to_records.py
+Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
+Extracting /tmp/data/train-images-idx3-ubyte.gz
+Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
+Extracting /tmp/data/train-labels-idx1-ubyte.gz
+Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
+Extracting /tmp/data/t10k-images-idx3-ubyte.gz
+Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
+Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
+Writing /tmp/data/train.tfrecords
+Writing /tmp/data/validation.tfrecords
+Writing /tmp/data/test.tfrecords
+
+- Locate the files under /tmp/data and copy them to a shared location
+
+$ ls -l /tmp/data/
+total 72252
+-rw-rw-r-- 1 stack stack  1648877 Feb  3 15:17 t10k-images-idx3-ubyte.gz
+-rw-rw-r-- 1 stack stack     4542 Feb  3 15:17 t10k-labels-idx1-ubyte.gz
+-rw-rw-r-- 1 stack stack  8910000 Feb  3 15:19 test.tfrecords
+-rw-rw-r-- 1 stack stack  9912422 Feb  3 15:17 train-images-idx3-ubyte.gz
+-rw-rw-r-- 1 stack stack    28881 Feb  3 15:17 train-labels-idx1-ubyte.gz
+-rw-rw-r-- 1 stack stack 49005000 Feb  3 15:19 train.tfrecords
+-rw-rw-r-- 1 stack stack  4455000 Feb  3 15:19 validation.tfrecords
+
+$ mkdir -p $HOME/mnist
+$ cp -r /tmp/data $HOME/mnist/
+$ cp mnist.py $HOME/mnist/
+
+- Optionally, upload the data to HDFS and use hdfs://... to specify the input
+  data location
diff --git a/yarn/example/mnist.py b/yarn/example/mnist.py
new file mode 100644
index 00000000..75758517
--- /dev/null
+++ b/yarn/example/mnist.py
@@ -0,0 +1,196 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# pylint: disable=bad-indentation,invalid-name
+from __future__ import print_function
+
+import os
+import sys
+
+import tensorflow as tf
+from tensorflow.examples.tutorials.mnist import mnist
+from tensorflow.python.framework import ops
+
+flags = tf.app.flags
+
+# Flags for configuring the task
+flags.DEFINE_string("job_name", None, "job name: worker or ps")
+flags.DEFINE_integer("task_index", 0,
+                     "Worker task index, should be >= 0. task_index=0 is "
+                     "the chief worker task that performs the variable "
+                     "initialization")
+flags.DEFINE_string("ps_hosts", "",
+                    "Comma-separated list of hostname:port pairs")
+flags.DEFINE_string("worker_hosts", "",
+                    "Comma-separated list of hostname:port pairs")
+
+# Training related flags
+flags.DEFINE_string("data_dir", None,
+                    "Directory where the mnist data is stored")
+flags.DEFINE_string("train_dir", None,
+                    "Directory for storing the checkpoints")
+flags.DEFINE_integer("hidden1", 128,
+                     "Number of units in the 1st hidden layer of the NN")
+flags.DEFINE_integer("hidden2", 128,
+                     "Number of units in the 2nd hidden layer of the NN")
+flags.DEFINE_integer("batch_size", 100, "Training batch size")
+flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
+flags.DEFINE_integer("steps", 10000, "Number of training steps")
+
+FLAGS = flags.FLAGS
+TRAIN_FILE = "train.tfrecords"
+
+def get_global_step(graph=None):
+  "Extracts the global step from the model"
+  graph = ops.get_default_graph() if graph is None else graph
+  global_step_tensor = None
+  global_step_tensors = graph.get_collection(ops.GraphKeys.GLOBAL_STEP)
+  if len(global_step_tensors) == 1:
+    global_step_tensor = global_step_tensors[0]
+  elif not global_step_tensors:
+    try:
+      global_step_tensor = graph.get_tensor_by_name('global_step:0')
+    except KeyError:
+      return None
+  return global_step_tensor
+
+def read_and_decode(filename_queue):
+  reader = tf.TFRecordReader()
+  _, serialized_example = reader.read(filename_queue)
+  features = tf.parse_single_example(
+      serialized_example,
+      # Defaults are not specified since both keys are required.
+      features={
+          'image_raw': tf.FixedLenFeature([], tf.string),
+          'label': tf.FixedLenFeature([], tf.int64),
+      })
+
+  # Convert from a scalar string tensor (whose single string has
+  # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
+  # [mnist.IMAGE_PIXELS].
+  image = tf.decode_raw(features['image_raw'], tf.uint8)
+  image.set_shape([mnist.IMAGE_PIXELS])
+
+  # OPTIONAL: Could reshape into a 28x28 image and apply distortions
+  # here. Since we are not applying any distortions in this
+  # example, and the next step expects the image to be flattened
+  # into a vector, we don't bother.
+
+  # Convert from [0, 255] -> [-0.5, 0.5] floats.
+  image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
+
+  # Convert label from a scalar uint8 tensor to an int32 scalar.
+  label = tf.cast(features['label'], tf.int32)
+
+  return image, label
+
+
+def inputs(batch_size):
+  """Reads input data.
+
+  Args:
+    batch_size: Number of examples per returned batch.
+
+  Returns:
+    A tuple (images, labels), where:
+    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
+      in the range [-0.5, 0.5].
+    * labels is an int32 tensor with shape [batch_size] with the true label,
+      a number in the range [0, mnist.NUM_CLASSES).
+  """
+  filename = os.path.join(FLAGS.data_dir, TRAIN_FILE)
+
+  with tf.name_scope('input'):
+    filename_queue = tf.train.string_input_producer([filename])
+
+    # Even when reading in multiple threads, share the filename
+    # queue.
+    image, label = read_and_decode(filename_queue)
+
+    # Shuffle the examples and collect them into batch_size batches.
+    # (Internally uses a RandomShuffleQueue.)
+    # We run this in two threads to avoid being a bottleneck.
+    images, sparse_labels = tf.train.shuffle_batch(
+        [image, label], batch_size=batch_size, num_threads=2,
+        capacity=1000 + 3 * batch_size,
+        # Ensures a minimum amount of shuffling of examples.
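+        # capacity is sized as min_after_dequeue plus 3 * batch_size of
+        # headroom, so the two reader threads can keep the queue topped up
+        # while at least min_after_dequeue examples stay buffered for
+        # shuffling.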
+ min_after_dequeue=1000) + + return images, sparse_labels + + +def device_and_target(): + # If FLAGS.job_name is not set, we're running single-machine TensorFlow. + # Don't set a device. + if FLAGS.job_name is None: + print("Running single-machine training") + return (None, "") + + # Otherwise we're running distributed TensorFlow. + print("Running distributed training") + if FLAGS.task_index is None or FLAGS.task_index == "": + raise ValueError("Must specify an explicit `task_index`") + if FLAGS.ps_hosts is None or FLAGS.ps_hosts == "": + raise ValueError("Must specify an explicit `ps_hosts`") + if FLAGS.worker_hosts is None or FLAGS.worker_hosts == "": + raise ValueError("Must specify an explicit `worker_hosts`") + + cluster_spec = tf.train.ClusterSpec({ + "ps": FLAGS.ps_hosts.split(","), + "worker": FLAGS.worker_hosts.split(","), + }) + server = tf.train.Server( + cluster_spec, job_name=FLAGS.job_name, task_index=FLAGS.task_index) + if FLAGS.job_name == "ps": + server.join() + + worker_device = "/job:worker/task:{}".format(FLAGS.task_index) + # The device setter will automatically place Variables ops on separate + # parameter servers (ps). The non-Variable ops will be placed on the workers. + return ( + tf.train.replica_device_setter( + worker_device=worker_device, + cluster=cluster_spec), + server.target, + ) + + +def main(unused_argv): + if FLAGS.data_dir is None or FLAGS.data_dir == "": + raise ValueError("Must specify an explicit `data_dir`") + if FLAGS.train_dir is None or FLAGS.train_dir == "": + raise ValueError("Must specify an explicit `train_dir`") + + device, target = device_and_target() + + with tf.device(device): + images, labels = inputs(FLAGS.batch_size) + logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2) + loss = mnist.loss(logits, labels) + train_op = mnist.training(loss, FLAGS.learning_rate) + evaluation_op = mnist.evaluation(logits, labels) + + sv = tf.train.Supervisor(logdir=FLAGS.train_dir, is_chief=(FLAGS.task_index == 0)) + + step = 0 + with sv.managed_session(target) as sess: + global_step_op = get_global_step(sess.graph) + while not sv.should_stop() and step < FLAGS.steps: + _, step, evaluation = sess.run([train_op, global_step_op, evaluation_op]) + print("Global step %d, evaluation: %f" % (int(step), float(evaluation))) + sys.stdout.flush() + +if __name__ == "__main__": + tf.app.run() diff --git a/yarn/pom.xml b/yarn/pom.xml new file mode 100644 index 00000000..92a6a59e --- /dev/null +++ b/yarn/pom.xml @@ -0,0 +1,187 @@ + + + + + hadoop-yarn-applications + org.apache.hadoop + 2.7.2 + + 4.0.0 + org.tensorflow + hadoop-yarn-tensorflow + 2.7.2 + hadoop-yarn-tensorflow + + + + org.apache.hadoop + hadoop-common + provided + + + junit + junit + test + + + log4j + log4j + + + commons-lang + commons-lang + + + com.google.guava + guava + + + commons-logging + commons-logging + + + commons-cli + commons-cli + + + commons-io + commons-io + + + org.apache.hadoop + hadoop-annotations + + + org.apache.hadoop + hadoop-yarn-api + + + org.apache.hadoop + hadoop-yarn-common + + + org.apache.hadoop + hadoop-yarn-client + + + org.apache.hadoop + hadoop-yarn-registry + + + javax.servlet + servlet-api + 2.5 + + + + + org.apache.hadoop + hadoop-common + test-jar + test + + + org.apache.hadoop + hadoop-yarn-server-resourcemanager + test + + + org.apache.hadoop + hadoop-yarn-server-applicationhistoryservice + 2.7.2 + test + + + org.apache.hadoop + hadoop-yarn-server-tests + test-jar + test + + + org.apache.zookeeper + zookeeper + 3.4.6 + test + + + + + + + maven-jar-plugin 
+ + + + jar + + + test-compile + + + + + + org.tensorflow.hadoop.yarn.Client + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${java.home} + ${env.HADOOP_HOME} + + + + + maven-assembly-plugin + 3.0.0 + + + archive-task-starter + validate + + single + + + false + + task-starter-assembly.xml + + gnu + task-starter + + + + + + + + ${project.build.directory} + + task-starter.zip + + + + + + diff --git a/yarn/src/main/java/org/tensorflow/hadoop/yarn/ApplicationMaster.java b/yarn/src/main/java/org/tensorflow/hadoop/yarn/ApplicationMaster.java new file mode 100644 index 00000000..2bdbafa5 --- /dev/null +++ b/yarn/src/main/java/org/tensorflow/hadoop/yarn/ApplicationMaster.java @@ -0,0 +1,1420 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.hadoop.yarn; + +/** + * TensorFlow launcher for YARN + * + * Modified from YARN sample application: DistributedShell. + */ + +import java.io.BufferedReader; +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringReader; +import java.lang.reflect.UndeclaredThrowableException; +import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.security.PrivilegedExceptionAction; +import java.util.*; +import java.util.Map.Entry; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.GnuParser; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.PathNotFoundException; +import org.apache.hadoop.http.HttpServer2; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.io.IOUtils; +import org.apache.hadoop.net.NetUtils; +import org.apache.hadoop.registry.client.api.BindFlags; +import org.apache.hadoop.registry.client.api.RegistryConstants; +import org.apache.hadoop.registry.client.api.RegistryOperations; +import org.apache.hadoop.registry.client.api.RegistryOperationsFactory; +import org.apache.hadoop.registry.client.types.ServiceRecord; +import org.apache.hadoop.registry.client.types.yarn.PersistencePolicies; +import org.apache.hadoop.registry.client.types.yarn.YarnRegistryAttributes; +import 
org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.util.ExitUtil; +import org.apache.hadoop.util.Shell; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment; +import org.apache.hadoop.yarn.api.ContainerManagementProtocol; +import org.apache.hadoop.yarn.api.protocolrecords.RegisterApplicationMasterResponse; +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId; +import org.apache.hadoop.yarn.api.records.Container; +import org.apache.hadoop.yarn.api.records.ContainerExitStatus; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; +import org.apache.hadoop.yarn.api.records.ContainerState; +import org.apache.hadoop.yarn.api.records.ContainerStatus; +import org.apache.hadoop.yarn.api.records.FinalApplicationStatus; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.NodeReport; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.Resource; +import org.apache.hadoop.yarn.api.records.timeline.TimelineEntity; +import org.apache.hadoop.yarn.api.records.timeline.TimelineEvent; +import org.apache.hadoop.yarn.api.records.timeline.TimelinePutResponse; +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest; +import org.apache.hadoop.yarn.client.api.TimelineClient; +import org.apache.hadoop.yarn.client.api.async.AMRMClientAsync; +import org.apache.hadoop.yarn.client.api.async.NMClientAsync; +import org.apache.hadoop.yarn.client.api.async.impl.NMClientAsyncImpl; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.security.AMRMTokenIdentifier; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.apache.log4j.LogManager; + +import com.google.common.annotations.VisibleForTesting; + +@InterfaceAudience.Public +@InterfaceStability.Unstable +public class ApplicationMaster +{ + private static final Log LOG = LogFactory.getLog(ApplicationMaster.class); + // used by taskStarter to pass back info to AM + private static final String TASK_JOB_NAME = "task_job_name"; + private static final String TASK_JOB_INDEX = "task_job_index"; + private static final String TASK_PORT = "task_port"; + + // Configuration + private Configuration conf; + + // Handle to communicate with the Resource Manager + private AMRMClientAsync amRMClient; + + // In both secure and non-secure modes, this points to the job-submitter. + @VisibleForTesting + UserGroupInformation appSubmitterUgi; + + // Handle to communicate with the Node Manager + private NMClientAsync nmClientAsync; + // Listen to process the response from the Node Manager + private NMCallbackHandler containerListener; + + // Application Attempt Id ( combination of attemptId and fail count ) + @VisibleForTesting + private ApplicationAttemptId appAttemptID; + + private String appIDString; + + // For status update for clients - yet to be implemented + // Hostname of the container + private String appMasterHostname = ""; + // Port on which the app master listens for status updates from clients + private int appMasterRpcPort = -1; + // Tracking url to which app master publishes info for clients to monitor + private String appMasterTrackingUrl = ""; + + // App Master configuration + // No. 
of containers to run shell command on + @VisibleForTesting + private int numTotalContainers = 1; + // Memory to request for the container on which the shell command will run + private int containerMemory = 10; + // VirtualCores to request for the container on which the shell command will + // run + private int containerVirtualCores = 1; + // Priority of the request + private int requestPriority; + + // Counter for completed containers ( complete denotes successful or failed + // ) + private AtomicInteger numCompletedContainers = new AtomicInteger(); + // Allocated container count so that we know how many containers has the RM + // allocated to us + @VisibleForTesting + private AtomicInteger numAllocatedContainers = new AtomicInteger(); + // Count of failed containers + private AtomicInteger numFailedContainers = new AtomicInteger(); + // Count of containers already requested from the RM + // Needed as once requested, we should not request for containers again. + // Only request for more if the original requirement changes. + @VisibleForTesting + private AtomicInteger numRequestedContainers = new AtomicInteger(); + + // Shell command to be executed + private String taskCommand = ""; + // Args to be passed to the shell command + private String taskArgs = ""; + // Env variables to be setup for the shell command + private Map shellEnv = new ConcurrentHashMap<>(); + + // Location of shell script ( obtained from info set in env ) + // Shell script path in fs + private String scriptPath = ""; + // Timestamp needed for creating a local resource + private long shellScriptPathTimestamp = 0; + // File length needed for local resource + private long shellScriptPathLen = 0; + + private Map localResources = new ConcurrentHashMap<>(); + + // Timeline domain ID + private String domainId = null; + + // Hardcoded path to custom log_properties + private static final String log4jPath = "log4j.properties"; + + private static final String shellCommandPath = "shellCommands"; + private static final String shellArgsPath = "shellArgs"; + + private volatile boolean done; + + private ByteBuffer allTokens; + + // Launch threads + private List launchThreads = new ArrayList<>(); + + // Timeline Client + @VisibleForTesting + TimelineClient timelineClient; + + private TFSession tfSession; + + private String appSubmitterUserName; + + private boolean tensorboardFlag = false; + + /** + * @param args + * Command line args + */ + public static void main(String[] args) + { + boolean result = false; + try + { + ApplicationMaster appMaster = new ApplicationMaster(); + LOG.info("TF: Initializing ApplicationMaster"); + boolean doRun = appMaster.init(args); + if (!doRun) + { + System.exit(0); + } + + appMaster.run(); + result = appMaster.finish(); + } catch (Throwable t) + { + LOG.fatal("Error running ApplicationMaster", t); + LogManager.shutdown(); + ExitUtil.terminate(1, t); + } + + if (result) + { + LOG.info("Application Master completed successfully. exiting"); + System.exit(0); + } else + { + LOG.info("Application Master failed. 
exiting"); + System.exit(2); + } + } + + /** + * Dump out contents of $CWD and the environment to stdout for debugging + */ + private void dumpOutDebugInfo() + { + + LOG.info("Dump debug output"); + Map envs = System.getenv(); + for (Map.Entry env : envs.entrySet()) + { + LOG.info("System env: key=" + env.getKey() + ", val=" + + env.getValue()); + System.out.println("System env: key=" + env.getKey() + ", val=" + + env.getValue()); + } + + BufferedReader buf = null; + try + { + String lines = Shell.WINDOWS ? Shell + .execCommand("cmd", "/c", "dir") : Shell.execCommand("ls", + "-al"); + buf = new BufferedReader(new StringReader(lines)); + String line; + while ((line = buf.readLine()) != null) + { + LOG.info("System CWD content: " + line); + System.out.println("System CWD content: " + line); + } + } catch (IOException e) + { + e.printStackTrace(); + } finally + { + IOUtils.cleanup(LOG, buf); + } + } + + public ApplicationMaster() + { + // Set up the configuration + conf = new YarnConfiguration(); + } + + /** + * Parse command line options + * + * @param args + * Command line args + * @return Whether init successful and run should be invoked + * @throws ParseException + * @throws IOException + */ + public boolean init(String[] args) throws ParseException, IOException + { + Options opts = new Options(); + opts.addOption("app_attempt_id", true, + "App Attempt ID. Not to be used unless for testing purposes"); + opts.addOption("shell_env", true, + "Environment for shell script. Specified as env_key=env_val pairs"); + opts.addOption("input_path", true, + "Input path of TensorFlow tasks"); + opts.addOption("output_path", true, + "Output path of TensorFlow tasks"); + opts.addOption("container_memory", true, + "Amount of memory in MB to be requested to run the shell command"); + opts.addOption("container_vcores", true, + "Amount of virtual cores to be requested to run the shell command"); + opts.addOption("num_containers", true, + "No. of containers on which the shell command needs to be executed"); + opts.addOption("priority", true, "Application Priority. Default 0"); + opts.addOption("debug", false, "Dump out debug information"); + opts.addOption("enable_tensorboard", false, "Start TensorBoard as part of job"); + opts.addOption("docker_image", true, "Docker image for running the tasks"); + opts.addOption("appname", true, "Application Name. Default: " + TFConstants.DEFAULT_APPNAME); + + opts.addOption("help", false, "Print usage"); + CommandLine cliParser = new GnuParser().parse(opts, args); + + if (args.length == 0) + { + printUsage(opts); + throw new IllegalArgumentException( + "No args specified for application master to initialize"); + } + + // Check whether customer log4j.properties file exists + if (fileExist(log4jPath)) + { + try + { + Log4jPropertyHelper.updateLog4jConfiguration( + ApplicationMaster.class, log4jPath); + } catch (Exception e) + { + LOG.warn("Can not set up custom log4j properties. 
" + e); + } + } + + if (cliParser.hasOption("help")) + { + printUsage(opts); + return false; + } + + if (cliParser.hasOption("debug")) + { + dumpOutDebugInfo(); + } + + if (cliParser.hasOption("enable_tensorboard")) + { + tensorboardFlag = true; + } + + Map envs = System.getenv(); + + if (!envs.containsKey(Environment.CONTAINER_ID.name())) + { + if (cliParser.hasOption("app_attempt_id")) + { + String appIdStr = cliParser + .getOptionValue("app_attempt_id", ""); + appAttemptID = ConverterUtils.toApplicationAttemptId(appIdStr); + } else + { + throw new IllegalArgumentException( + "Application Attempt Id not set in the environment"); + } + } else + { + ContainerId containerId = ConverterUtils.toContainerId(envs + .get(Environment.CONTAINER_ID.name())); + appAttemptID = containerId.getApplicationAttemptId(); + } + + appIDString = appAttemptID.getApplicationId().toString(); + + if (!envs.containsKey(ApplicationConstants.APP_SUBMIT_TIME_ENV)) + { + throw new RuntimeException(ApplicationConstants.APP_SUBMIT_TIME_ENV + + " not set in the environment"); + } + if (!envs.containsKey(Environment.NM_HOST.name())) + { + throw new RuntimeException(Environment.NM_HOST.name() + + " not set in the environment"); + } + if (!envs.containsKey(Environment.NM_HTTP_PORT.name())) + { + throw new RuntimeException(Environment.NM_HTTP_PORT + + " not set in the environment"); + } + if (!envs.containsKey(Environment.NM_PORT.name())) + { + throw new RuntimeException(Environment.NM_PORT.name() + + " not set in the environment"); + } + + LOG.info("Application master for app" + ", appId=" + + appAttemptID.getApplicationId().getId() + + ", clustertimestamp=" + + appAttemptID.getApplicationId().getClusterTimestamp() + + ", attemptId=" + appAttemptID.getAttemptId()); + + if (!fileExist(shellCommandPath) + && envs.get(TFConstants.SCRIPTLOCATION) + .isEmpty()) + { + throw new IllegalArgumentException( + "No shell command or shell script specified to be executed by application master"); + } + + if (fileExist(shellCommandPath)) + { + taskCommand = readContent(shellCommandPath); + } + + if (fileExist(shellArgsPath)) + { + taskArgs = readContent(shellArgsPath); + } + + if (cliParser.hasOption("shell_env")) + { + String shellEnvs[] = cliParser.getOptionValues("shell_env"); + for (String env : shellEnvs) + { + env = env.trim(); + int index = env.indexOf('='); + if (index == -1) + { + shellEnv.put(env, ""); + continue; + } + String key = env.substring(0, index); + String val = ""; + if (index < (env.length() - 1)) + { + val = env.substring(index + 1); + } + shellEnv.put(key, val); + } + } + + if (envs.containsKey(TFConstants.SCRIPTLOCATION)) + { + scriptPath = envs.get(TFConstants.SCRIPTLOCATION); + + if (envs.containsKey(TFConstants.SCRIPTTIMESTAMP)) + { + shellScriptPathTimestamp = Long.parseLong(envs + .get(TFConstants.SCRIPTTIMESTAMP)); + } + if (envs.containsKey(TFConstants.SHELLSCRIPTLEN)) + { + shellScriptPathLen = Long.parseLong(envs + .get(TFConstants.SHELLSCRIPTLEN)); + } + if (!scriptPath.isEmpty() + && (shellScriptPathTimestamp <= 0 || shellScriptPathLen <= 0)) + { + LOG.error("Illegal values in env for shell script path" + + ", path=" + scriptPath + ", len=" + + shellScriptPathLen + ", timestamp=" + + shellScriptPathTimestamp); + throw new IllegalArgumentException( + "Illegal values in env for shell script path"); + } + } + + String input_path = cliParser.getOptionValue("input_path", ""); + String output_path = cliParser.getOptionValue("output_path", ""); + String dockerImage = cliParser.getOptionValue("docker_image", 
""); + + String appName = cliParser.getOptionValue("appname", TFConstants.DEFAULT_APPNAME); + + String registryQuorum = conf.get(RegistryConstants.KEY_REGISTRY_ZK_QUORUM); + if (registryQuorum == null || registryQuorum.isEmpty()) + { + throw new IllegalArgumentException( + "Undefined mandatoryconfiguration <" + + RegistryConstants.KEY_REGISTRY_ZK_QUORUM + ">"); + } + + if (envs.containsKey(TFConstants.TIMELINEDOMAIN)) + { + domainId = envs.get(TFConstants.TIMELINEDOMAIN); + } + + + // default container vcores/memory/priority request + containerMemory = Integer.parseInt(cliParser.getOptionValue( + "container_memory", "10")); + containerVirtualCores = Integer.parseInt(cliParser.getOptionValue( + "container_vcores", "1")); + requestPriority = Integer.parseInt(cliParser.getOptionValue("priority", + "0")); + + String clusterReqString = cliParser.getOptionValue("num_containers", ""); + + // create TensorFlow session object + TFSession.TFSessionBuilder builder = new TFSession.TFSessionBuilder(); + builder.setClusterReqString(clusterReqString); + builder.setAppName(appName); + builder.setAppIDString(appIDString); + builder.setTaskCmd(taskCommand); + builder.setTaskArgs(taskArgs); + builder.setEnableTensorBoard(tensorboardFlag); + builder.setScriptPath(scriptPath); + builder.setInputPath(input_path); + builder.setOutputPath(output_path); + builder.setDockerImage(dockerImage); + builder.setRegistryQuorum(registryQuorum); + TFContainerRequest defaultRequest = new TFContainerRequest(containerVirtualCores, + containerMemory, + requestPriority); + builder.setDefaultContainerRequest(defaultRequest); + tfSession = builder.build(); + + // Add application wide environment variables + tfSession.setAppGlobalEnv(shellEnv); + + ArrayList requests = tfSession.getRequiredContainers(); + numTotalContainers = requests.size(); + if (numTotalContainers == 0) + { + throw new IllegalArgumentException( + "Cannot run TensorFlow with no tasks"); + } + + return true; + } + + + /** + * Helper function to print usage + * + * @param opts + * Parsed command line options + */ + private void printUsage(Options opts) + { + new HelpFormatter().printHelp("ApplicationMaster", opts); + } + + /** + * Main run function for the application master + * + * @throws YarnException + * @throws IOException + * @throws URISyntaxException + */ + @SuppressWarnings({ "unchecked" }) + public void run() throws YarnException, IOException, InterruptedException, URISyntaxException + { + LOG.info("Starting ApplicationMaster"); + + // Note: Credentials, Token, UserGroupInformation, DataOutputBuffer + // class + // are marked as LimitedPrivate + Credentials credentials = UserGroupInformation.getCurrentUser() + .getCredentials(); + DataOutputBuffer dob = new DataOutputBuffer(); + credentials.writeTokenStorageToStream(dob); + // Now remove the AM->RM token so that containers cannot access it. 
+ Iterator> iter = credentials.getAllTokens().iterator(); + LOG.info("Executing with tokens:"); + while (iter.hasNext()) + { + Token token = iter.next(); + LOG.info(token); + if (token.getKind().equals(AMRMTokenIdentifier.KIND_NAME)) + { + iter.remove(); + } + } + allTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength()); + + // Create appSubmitterUgi and add original tokens to it + appSubmitterUserName = System + .getenv(ApplicationConstants.Environment.USER.name()); + appSubmitterUgi = UserGroupInformation + .createRemoteUser(appSubmitterUserName); + appSubmitterUgi.addCredentials(credentials); + + AMRMClientAsync.CallbackHandler allocListener = new RMCallbackHandler(); + amRMClient = AMRMClientAsync.createAMRMClientAsync(1000, allocListener); + amRMClient.init(conf); + amRMClient.start(); + + containerListener = createNMCallbackHandler(); + nmClientAsync = new NMClientAsyncImpl(containerListener); + nmClientAsync.init(conf); + nmClientAsync.start(); + + startTimelineClient(conf); + if (timelineClient != null) + { + publishApplicationAttemptEvent(timelineClient, + appAttemptID.toString(), TFEvent.TF_APP_ATTEMPT_START, + domainId, appSubmitterUgi); + } + + // Setup local RPC Server to accept status requests directly from + // clients + + String localHostname = NetUtils.getHostname(); + + HttpServer2.Builder builder = new HttpServer2.Builder() + .setName("test") + .addEndpoint(URI.create("http://" + localHostname + ":0")) + .setFindPort(true); + + HttpServer2 server = builder.build(); + server.setAttribute(ApplicationMaster.class.getName(), this); + server.addServlet("status", "/", TFServlet.class); + + try + { + server.start(); + + InetSocketAddress addr = server.getConnectorAddress(0); + appMasterHostname = addr.getHostName(); + appMasterRpcPort = addr.getPort(); + appMasterTrackingUrl = "http://" + NetUtils.getHostPortString(addr); + + LOG.info(String.format("Server started: hostname=<%s>, port=<%d>, url=<%s>", + appMasterHostname, appMasterRpcPort, appMasterTrackingUrl)); + } catch (Exception e) + { + e.printStackTrace(); + } + + // Register self with ResourceManager + // This will start heartbeating to the RM + RegisterApplicationMasterResponse response = amRMClient + .registerApplicationMaster(appMasterHostname, appMasterRpcPort, + appMasterTrackingUrl); + // Dump out information about cluster capability as seen by the + // resource manager + int maxMem = response.getMaximumResourceCapability().getMemory(); + LOG.info("Max mem capability of resources in this cluster " + maxMem); + + int maxVCores = response.getMaximumResourceCapability() + .getVirtualCores(); + LOG.info("Max vcores capability of resources in this cluster " + + maxVCores); + + // A resource ask cannot exceed the max. + if (containerMemory > maxMem) + { + LOG.info("Container memory specified above max threshold of cluster." + + " Using max value." + + ", specified=" + + containerMemory + + ", max=" + maxMem); + containerMemory = maxMem; + } + + if (containerVirtualCores > maxVCores) + { + LOG.info("Container virtual cores specified above max threshold of cluster." + + " Using max value." 
+ + ", specified=" + + containerVirtualCores + ", max=" + maxVCores); + containerVirtualCores = maxVCores; + } + + // open HDFS and write out resources for the session + FileSystem fs = FileSystem.get(conf); + tfSession.createResources(fs, localResources); + + // Setup ask for containers from RM + // Send request for containers to RM + // Until we get our fully allocated quota, we keep on polling RM for + // containers + // Keep looping until all the containers are launched and shell script + // executed on them ( regardless of success/failure). + + ArrayList requests = tfSession.getRequiredContainers(); + for (TFContainerRequest request : requests) + { + ContainerRequest containerAsk = setupContainerAskForRM(request); + amRMClient.addContainerRequest(containerAsk); + } + + numRequestedContainers.set(requests.size()); + } + + private boolean scanRegistryRecord() + { + boolean ready = false; + + RegistryOperations regOps = RegistryOperationsFactory.createInstance(conf); + regOps.start(); + + try + { + String parentPath = getRegistryBasePath() + "/components"; + + LOG.info("TF: attempting to list parentPath=<" + parentPath + ">"); + + List components = regOps.list(parentPath); + + for (String componentName : components) + { + String path = String.format("%s/%s", parentPath, componentName); + + LOG.info("TF: attempting to get path=<" + path + ">"); + + ServiceRecord record = regOps.resolve(path); + + LOG.info("TF: container_id=<" + componentName + "> record=<" + + record.toString() + ">"); + + String jobName = record.get(TASK_JOB_NAME); + int index = Integer.valueOf(record.get(TASK_JOB_INDEX)); + int port = Integer.valueOf(record.get(TASK_PORT)); + + ready = tfSession.updateAllocatedPort(jobName, index, port); + if (ready) + break; + } + } + catch (PathNotFoundException e) + { + // path not created yet + ready = false; + } + catch (Exception e) + { + LOG.error(e); + e.printStackTrace(); + + } + finally + { + regOps.stop(); + } + + return ready; + } + + private String getRegistryBasePath() { + // Format: + // /users/{username}/{serviceclass}/{instancename}/components/{componentname} + // NOTE: + // the real path will have "hadoop.registry.zk.root" prepended (default: /registry) + return String.format("/users/%s/%s/%s", appSubmitterUserName, + TFSession.REGISTRY_SERVICE_CLASS, + this.appIDString); + } + + private void setRegistryRecord(Map varmap) + { + String path = getRegistryBasePath(); + + ServiceRecord record = new ServiceRecord(); + record.set(YarnRegistryAttributes.YARN_ID, this.appIDString); + record.set(YarnRegistryAttributes.YARN_PERSISTENCE, PersistencePolicies.APPLICATION); + record.description = "YARN TensorFlow Application Master"; + + for (Entry entry : varmap.entrySet()) + { + String jobName = entry.getKey(); + record.set(jobName, entry.getValue()); + } + + LOG.info(String.format("Setting registry record %s to %s", path, record)); + RegistryOperations regOps = RegistryOperationsFactory.createInstance(conf); + regOps.start(); + + try + { + regOps.bind(path, record, BindFlags.OVERWRITE); + } + catch (Exception e) + { + LOG.error(e); + } + finally + { + regOps.stop(); + } + } + + @VisibleForTesting + private void startTimelineClient(final Configuration conf) throws YarnException, + IOException, InterruptedException + { + try + { + appSubmitterUgi.doAs(new PrivilegedExceptionAction() + { + @Override + public Void run() throws Exception + { + if (conf.getBoolean( + YarnConfiguration.TIMELINE_SERVICE_ENABLED, + YarnConfiguration.DEFAULT_TIMELINE_SERVICE_ENABLED)) + { + // Creating the 
Timeline Client + timelineClient = TimelineClient.createTimelineClient(); + timelineClient.init(conf); + timelineClient.start(); + } else + { + timelineClient = null; + LOG.warn("Timeline service is not enabled"); + } + return null; + } + }); + } catch (UndeclaredThrowableException e) + { + throw new YarnException(e.getCause()); + } + } + + @VisibleForTesting + private NMCallbackHandler createNMCallbackHandler() + { + return new NMCallbackHandler(this); + } + + @VisibleForTesting + private boolean finish() + { + boolean registrySet = false; + + // wait for completion. + while (!done && (numCompletedContainers.get() != numTotalContainers)) + { + try + { + if (!registrySet && numAllocatedContainers.get() >= numTotalContainers) + { + boolean ready = scanRegistryRecord(); + if (ready) { + LOG.info("TF: All tasks ready; signal tasks to start execution."); + this.setRegistryRecord(tfSession.getClusterSpec()); + registrySet = true; + } + else + { + LOG.info("TF: Not all tasks are ready yet; check again soon."); + } + Thread.sleep(1000); + } + else + { + // sleep + Thread.sleep(200); + } + + } catch (InterruptedException ex) + { + LOG.error(ex); + } + } + + if (timelineClient != null) + { + publishApplicationAttemptEvent(timelineClient, + appAttemptID.toString(), TFEvent.TF_APP_ATTEMPT_END, + domainId, appSubmitterUgi); + } + + // Join all launched threads + // needed for when we time out + // and we need to release containers + for (Thread launchThread : launchThreads) + { + try + { + launchThread.join(10000); + } catch (InterruptedException e) + { + LOG.info("Exception thrown in thread join: " + e.getMessage()); + e.printStackTrace(); + } + } + + // When the application completes, it should stop all running containers + LOG.info("Application completed. Stopping running containers"); + nmClientAsync.stop(); + + // When the application completes, it should send a finish application + // signal to the RM + LOG.info("Application completed. 
Signalling finish to RM"); + + boolean success = true; + + FinalApplicationStatus appStatus = tfSession.getFinalStatus(); + String appMessage = tfSession.getFinalMessage(); + if (appStatus != FinalApplicationStatus.SUCCEEDED) + { + LOG.info(appMessage); + success = false; + } + + try + { + amRMClient.unregisterApplicationMaster(appStatus, appMessage, null); + } catch (YarnException | IOException ex) + { + LOG.error("Failed to unregister application", ex); + } + + amRMClient.waitForServiceToStop(5000); + amRMClient.stop(); + + // Stop Timeline Client + if (timelineClient != null) + { + timelineClient.stop(); + } + + return success; + } + + @VisibleForTesting + @InterfaceAudience.Private + public enum TFEvent { + TF_APP_ATTEMPT_START, TF_APP_ATTEMPT_END, TF_CONTAINER_START, TF_CONTAINER_END + } + + @VisibleForTesting + @InterfaceAudience.Private + public enum TFEntity { + TF_APP_ATTEMPT, TF_CONTAINER + } + + @VisibleForTesting + @InterfaceAudience.Private + public enum TFInfo { + TF_TASK_NAME, TF_EXIT_STATUS, TF_STATE + } + + private class RMCallbackHandler implements AMRMClientAsync.CallbackHandler + { + @SuppressWarnings("unchecked") + @Override + public void onContainersCompleted( + List completedContainers) + { + LOG.info("Got response from RM for container ask, completedCnt=" + + completedContainers.size()); + + for (ContainerStatus containerStatus : completedContainers) + { + LOG.info(appAttemptID + + " got container status for containerID=" + + containerStatus.getContainerId() + ", state=" + + containerStatus.getState() + ", exitStatus=" + + containerStatus.getExitStatus() + ", diagnostics=" + + containerStatus.getDiagnostics()); + + // non complete containers should not be here + assert (containerStatus.getState() == ContainerState.COMPLETE); + + // increment counters for completed/failed containers + int exitStatus = containerStatus.getExitStatus(); + + tfSession.handleContainerTaskCompleted(containerStatus.getContainerId(), exitStatus); + + if (0 != exitStatus) + { + if (ContainerExitStatus.ABORTED != exitStatus) + { + // failed, but counts as completed + numCompletedContainers.incrementAndGet(); + numFailedContainers.incrementAndGet(); + } + LOG.info("Container failed." + + ", containerId=" + + containerStatus.getContainerId()); + } + else + { + // nothing to do + // container completed successfully + numCompletedContainers.incrementAndGet(); + LOG.info("Container completed successfully." 
+ + ", containerId=" + + containerStatus.getContainerId()); + } + + if (timelineClient != null) + { + publishContainerEndEvent(timelineClient, containerStatus, + domainId, appSubmitterUgi); + } + } + + // ask for more containers if any failed + int askCount = numTotalContainers - numRequestedContainers.get(); + numRequestedContainers.addAndGet(askCount); + + if (askCount > 0) + { + for (int i = 0; i < askCount; ++i) + { + ContainerRequest containerAsk = setupContainerAskForRM(); + amRMClient.addContainerRequest(containerAsk); + } + } + + done = tfSession.isDone(); + + LOG.info("TF: numTotalContainers=" + numTotalContainers + + ", numCompletedContainers=" + numCompletedContainers + + ", numFailedContainers=" + numFailedContainers + + ", askCount=" + askCount + + ", done=" + done); + + } + + @Override + public void onContainersAllocated(List allocatedContainers) + { + LOG.info("Got response from RM for container ask, allocatedCnt=" + + allocatedContainers.size()); + + numAllocatedContainers.addAndGet(allocatedContainers.size()); + + for (Container allocatedContainer : allocatedContainers) + { + LOG.info("Launching task on a new container." + + ", containerId=" + allocatedContainer.getId() + + ", containerNode=" + + allocatedContainer.getNodeId().getHost() + ":" + + allocatedContainer.getNodeId().getPort() + + ", containerNodeURI=" + + allocatedContainer.getNodeHttpAddress() + + ", containerResourceMemory" + + allocatedContainer.getResource().getMemory() + + ", containerResourceVirtualCores" + + allocatedContainer.getResource().getVirtualCores()); + // + ", containerToken" + // +allocatedContainer.getContainerToken().getIdentifier().toString()); + + LaunchContainerRunnable runnableLaunchContainer = new LaunchContainerRunnable( + allocatedContainer, containerListener); + Thread launchThread = new Thread(runnableLaunchContainer); + + // launch and start the container on a separate thread to keep + // the main thread unblocked + // as all containers may not be allocated at one go. 
+ launchThreads.add(launchThread); + launchThread.start(); + } + } + + @Override + public void onShutdownRequest() + { + done = true; + } + + @Override + public void onNodesUpdated(List updatedNodes) + { + } + + @Override + public float getProgress() + { + // set progress to deliver to RM on next heartbeat + return (float) numCompletedContainers.get() / numTotalContainers; + } + + @Override + public void onError(Throwable e) + { + done = true; + amRMClient.stop(); + } + } + + @VisibleForTesting + static class NMCallbackHandler implements NMClientAsync.CallbackHandler + { + + private ConcurrentMap containers = new ConcurrentHashMap<>(); + private final ApplicationMaster applicationMaster; + + public NMCallbackHandler(ApplicationMaster applicationMaster) + { + this.applicationMaster = applicationMaster; + } + + public void addContainer(ContainerId containerId, Container container) + { + containers.putIfAbsent(containerId, container); + } + + @Override + public void onContainerStopped(ContainerId containerId) + { + if (LOG.isDebugEnabled()) + { + LOG.debug("Succeeded to stop Container " + containerId); + } + containers.remove(containerId); + } + + @Override + public void onContainerStatusReceived(ContainerId containerId, + ContainerStatus containerStatus) + { + if (LOG.isDebugEnabled()) + { + LOG.debug("Container Status: id=" + containerId + ", status=" + + containerStatus); + } + } + + @Override + public void onContainerStarted(ContainerId containerId, + Map allServiceResponse) + { + if (LOG.isDebugEnabled()) + { + LOG.debug("Succeeded to start Container " + containerId); + } + Container container = containers.get(containerId); + if (container != null) + { + applicationMaster.nmClientAsync.getContainerStatusAsync( + containerId, container.getNodeId()); + } + if (applicationMaster.timelineClient != null) + { + ApplicationMaster.publishContainerStartEvent( + applicationMaster.timelineClient, container, + applicationMaster.domainId, + applicationMaster.appSubmitterUgi); + } + } + + @Override + public void onStartContainerError(ContainerId containerId, Throwable t) + { + LOG.error("Failed to start Container " + containerId); + containers.remove(containerId); + applicationMaster.numCompletedContainers.incrementAndGet(); + applicationMaster.numFailedContainers.incrementAndGet(); + } + + @Override + public void onGetContainerStatusError(ContainerId containerId, + Throwable t) + { + LOG.error("Failed to query the status of Container " + containerId); + } + + @Override + public void onStopContainerError(ContainerId containerId, Throwable t) + { + LOG.error("Failed to stop Container " + containerId); + containers.remove(containerId); + } + } + + /** + * Thread to connect to the {@link ContainerManagementProtocol} and launch + * the container that will execute the shell command. + */ + private class LaunchContainerRunnable implements Runnable + { + + // Allocated container + Container container; + + NMCallbackHandler containerListener; + + /** + * @param lcontainer + * Allocated container + * @param containerListener + * Callback handler of the container + */ + public LaunchContainerRunnable(Container lcontainer, + NMCallbackHandler containerListener) + { + this.container = lcontainer; + this.containerListener = containerListener; + } + + @Override + /* + Connects to CM, sets up container launch context + for shell command and eventually dispatches the container + start request to the CM. 
+   */
+  public void run()
+  {
+    LOG.info(String.format(
+        "TF: Adding container on host=<%s> to TensorFlowApp",
+        container.getNodeId().getHost()));
+
+    // Set the local resources
+    Map<String, LocalResource> myLocalResources = new ConcurrentHashMap<>(localResources);
+
+    // Make a copy of the general environment variables; "addContainer"
+    // will add more container-specific ones to it.
+    Map<String, String> myShellEnv = new ConcurrentHashMap<>(shellEnv);
+
+    // add DTF_TASK_JOB_NAME and DTF_TASK_INDEX to myShellEnv
+    boolean added = tfSession.addContainer(container, myShellEnv);
+    if (!added)
+    {
+      LOG.info("TF: got extra container; releasing container id=<" + container.getId() + ">");
+      numAllocatedContainers.decrementAndGet();
+      amRMClient.releaseAssignedContainer(container.getId());
+      return;
+    }
+
+    // Set the necessary command to execute on the allocated container
+    Vector<CharSequence> vargs = new Vector<>(5);
+
+    vargs.add(tfSession.getTaskStarterCommand());
+    // Add log redirect params
+    vargs.add("1>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+        + "/stdout");
+    vargs.add("2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR
+        + "/stderr");
+
+    // Get final command
+    StringBuilder command = new StringBuilder();
+    for (CharSequence str : vargs)
+    {
+      command.append(str).append(" ");
+    }
+
+    List<String> commands = new ArrayList<>();
+    commands.add(command.toString());
+
+    LOG.info(String.format("TF: constructed wrapper command=<%s>", commands));
+
+    // Set up ContainerLaunchContext, setting local resources,
+    // environment, command and tokens for the constructor.
+    //
+    // Note for tokens: set up tokens for the container.
+    // This is required for the NodeManager and container to download files from DFS.
+    ContainerLaunchContext ctx = ContainerLaunchContext.newInstance(
+        myLocalResources, myShellEnv, commands, null,
+        allTokens.duplicate(), null);
+    containerListener.addContainer(container.getId(), container);
+    nmClientAsync.startContainerAsync(container, ctx);
+  }
+ }
+
+ /**
+  * Set up the request that will be sent to the RM for the container ask.
+  *
+  * @return the ResourceRequest to be sent to the RM
+  */
+ private ContainerRequest setupContainerAskForRM(int virtualCores, int memory, int priority)
+ {
+   // set the priority for the request
+   Priority pri = Priority.newInstance(priority);
+
+   // Set up resource type requirements
+   Resource capability = Resource.newInstance(memory, virtualCores);
+
+   // Currently no locality requirements.
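The four-argument ContainerRequest constructor used below also accepts node and rack hints; if locality ever mattered, a hypothetical locality-aware ask could look like this (host and rack names are placeholders):

```java
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest;

class LocalityAsk {
    // Hypothetical: steer a worker toward the node holding its input shard.
    // Passing null for nodes and racks, as the code below does, means the
    // RM may place the container anywhere.
    static ContainerRequest nearData(int memoryMb, int vcores, int priority) {
        Resource capability = Resource.newInstance(memoryMb, vcores);
        String[] nodes = { "datanode-01.example.com" }; // placeholder host
        String[] racks = { "/default-rack" };           // placeholder rack
        return new ContainerRequest(capability, nodes, racks,
                Priority.newInstance(priority));
    }
}
```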
+ String[] nodes = null; + String[] racks = null; + + ContainerRequest request = new ContainerRequest(capability, nodes, racks, pri); + LOG.info("Requested container ask: " + request.toString()); + return request; + } + + private ContainerRequest setupContainerAskForRM() + { + return setupContainerAskForRM(this.containerVirtualCores, + this.containerMemory, this.requestPriority); + } + + private ContainerRequest setupContainerAskForRM(TFContainerRequest request) + { + return setupContainerAskForRM(request.getVirtualCores(), + request.getMemory(), request.getPriority()); + } + + private boolean fileExist(String filePath) + { + return new File(filePath).exists(); + } + + private String readContent(String filePath) throws IOException + { + DataInputStream ds = null; + try + { + ds = new DataInputStream(new FileInputStream(filePath)); + return ds.readUTF(); + } finally + { + org.apache.commons.io.IOUtils.closeQuietly(ds); + } + } + + private static void publishContainerStartEvent( + final TimelineClient timelineClient, Container container, + String domainId, UserGroupInformation ugi) + { + final TimelineEntity entity = new TimelineEntity(); + entity.setEntityId(container.getId().toString()); + entity.setEntityType(TFEntity.TF_CONTAINER.toString()); + entity.setDomainId(domainId); + entity.addPrimaryFilter("user", ugi.getShortUserName()); + TimelineEvent event = new TimelineEvent(); + event.setTimestamp(System.currentTimeMillis()); + event.setEventType(TFEvent.TF_CONTAINER_START.toString()); + event.addEventInfo("Node", container.getNodeId().toString()); + event.addEventInfo("Resources", container.getResource().toString()); + entity.addEvent(event); + + try + { + ugi.doAs(new PrivilegedExceptionAction() + { + @Override + public TimelinePutResponse run() throws Exception + { + return timelineClient.putEntities(entity); + } + }); + } catch (Exception e) + { + LOG.error("Container start event could not be published for " + + container.getId().toString(), + e instanceof UndeclaredThrowableException ? 
e.getCause() + : e); + } + } + + private void publishContainerEndEvent( + final TimelineClient timelineClient, ContainerStatus container, + String domainId, UserGroupInformation ugi) + { + final TimelineEntity entity = new TimelineEntity(); + entity.setEntityId(container.getContainerId().toString()); + entity.setEntityType(TFEntity.TF_CONTAINER.toString()); + entity.setDomainId(domainId); + entity.addPrimaryFilter("user", ugi.getShortUserName()); + TimelineEvent event = new TimelineEvent(); + event.setTimestamp(System.currentTimeMillis()); + event.setEventType(TFEvent.TF_CONTAINER_END.toString()); + event.addEventInfo(TFInfo.TF_STATE.toString(), container.getState().name()); + event.addEventInfo(TFInfo.TF_EXIT_STATUS.toString(), container.getExitStatus()); + + String taskName = tfSession.getJobAndIndex(container.getContainerId()); + event.addEventInfo(TFInfo.TF_TASK_NAME.toString(), taskName); + + entity.addEvent(event); + try + { + timelineClient.putEntities(entity); + } catch (YarnException | IOException e) + { + LOG.error("Container end event could not be published for " + + container.getContainerId().toString(), e); + } + } + + private static void publishApplicationAttemptEvent( + final TimelineClient timelineClient, String appAttemptId, + TFEvent appEvent, String domainId, UserGroupInformation ugi) + { + final TimelineEntity entity = new TimelineEntity(); + entity.setEntityId(appAttemptId); + entity.setEntityType(TFEntity.TF_APP_ATTEMPT.toString()); + entity.setDomainId(domainId); + entity.addPrimaryFilter("user", ugi.getShortUserName()); + TimelineEvent event = new TimelineEvent(); + event.setEventType(appEvent.toString()); + event.setTimestamp(System.currentTimeMillis()); + entity.addEvent(event); + try + { + timelineClient.putEntities(entity); + } catch (YarnException | IOException e) + { + LOG.error( + "App Attempt " + + (appEvent.equals(TFEvent.TF_APP_ATTEMPT_START) ? "start" + : "end") + + " event could not be published for " + + appAttemptId, e); + } + } + + void printHtmlStatus(PrintWriter out) + { + out.println(""); + out.println("

<html><head><title>YARN TensorFlow Application Status Page</title></head><body>
"); + + tfSession.printHtmlStatusTable(out); + + out.println(""); + } + +} diff --git a/yarn/src/main/java/org/tensorflow/hadoop/yarn/Client.java b/yarn/src/main/java/org/tensorflow/hadoop/yarn/Client.java new file mode 100644 index 00000000..71d81b43 --- /dev/null +++ b/yarn/src/main/java/org/tensorflow/hadoop/yarn/Client.java @@ -0,0 +1,738 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.hadoop.yarn; + +/** + * TensorFlow launcher for YARN + * + * Modified from YARN sample application: DistributedShell. + */ + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Vector; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.GnuParser; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.commons.io.IOUtils; +import org.apache.commons.lang.StringUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.io.DataOutputBuffer; +import org.apache.hadoop.security.Credentials; +import org.apache.hadoop.security.UserGroupInformation; +import org.apache.hadoop.security.token.Token; +import org.apache.hadoop.yarn.api.ApplicationConstants; +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment; +import org.apache.hadoop.yarn.api.protocolrecords.GetNewApplicationResponse; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ApplicationSubmissionContext; +import org.apache.hadoop.yarn.api.records.ContainerLaunchContext; +import org.apache.hadoop.yarn.api.records.LocalResource; +import org.apache.hadoop.yarn.api.records.LocalResourceType; +import org.apache.hadoop.yarn.api.records.LocalResourceVisibility; +import org.apache.hadoop.yarn.api.records.NodeReport; +import org.apache.hadoop.yarn.api.records.NodeState; +import org.apache.hadoop.yarn.api.records.Priority; +import org.apache.hadoop.yarn.api.records.QueueACL; +import org.apache.hadoop.yarn.api.records.QueueInfo; +import org.apache.hadoop.yarn.api.records.QueueUserACLInfo; +import org.apache.hadoop.yarn.api.records.Resource; +import 
org.apache.hadoop.yarn.api.records.YarnClusterMetrics; +import org.apache.hadoop.yarn.api.records.timeline.TimelineDomain; +import org.apache.hadoop.yarn.client.api.TimelineClient; +import org.apache.hadoop.yarn.client.api.YarnClient; +import org.apache.hadoop.yarn.client.api.YarnClientApplication; +import org.apache.hadoop.yarn.conf.YarnConfiguration; +import org.apache.hadoop.yarn.exceptions.YarnException; +import org.apache.hadoop.yarn.util.ConverterUtils; +import org.apache.hadoop.yarn.util.timeline.TimelineUtils; + +@InterfaceAudience.Public +@InterfaceStability.Unstable +public class Client { + + private static final Log LOG = LogFactory.getLog(Client.class); + + // Configuration + private Configuration conf; + private YarnClient yarnClient; + // Application master specific info to register a new Application with RM/ASM + private String appName = ""; + // App master priority + private int amPriority = 0; + // Queue for App master + private String amQueue = ""; + // Amt. of memory resource to request for to run the App Master + private int amMemory = 10; + // Amt. of virtual core resource to request for to run the App Master + private int amVCores = 1; + + // Application master jar file + private String appMasterJar = ""; + // Main class to invoke application master + private final String appMasterMainClass; + + // Shell command to be executed + private String taskCmd = ""; + // Location of shell script + private String taskScript = ""; + // Args to be passed to the shell command + private String[] taskArgs = new String[] {}; + // Input path + private String inputPath = ""; + // Output path + private String outputPath = ""; + // Docker engine to run the task command + private String dockerImage = ""; + // Env variables to be setup for the shell command + private Map shellEnv = new HashMap<>(); + // Shell Command Container priority + private int shellCmdPriority = 0; + + // Amt of memory to request for container in which shell script will be executed + private int containerMemory = 10; + // Amt. of virtual cores to request for container in which shell script will be executed + private int containerVirtualCores = 1; + // No. of containers in which the shell script needs to be executed + private String numContainers = ""; + private String nodeLabelExpression = null; + + // log4j.properties file + // if available, add to local resources and set into classpath + private String log4jPropFile = ""; + + // flag to indicate whether to keep containers across application attempts. 
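Before the remaining fields, a hypothetical end-to-end use of this Client from Java, showing the minimum options init(...) below requires; the jar name is a placeholder, and mnist.py refers to the example shipped under yarn/example:

```java
import org.tensorflow.hadoop.yarn.Client;

public class SubmitExample {
    public static void main(String[] args) throws Exception {
        String[] clientArgs = {
            "-jar", "tensorflow-yarn.jar",       // placeholder AM jar name
            "-task_script", "example/mnist.py",  // shipped to each container
            "-num_containers", "ps:1,worker:2",  // one PS task, two workers
            "-container_memory", "2048",
            "-container_vcores", "1"
        };
        Client client = new Client();            // uses a default YarnConfiguration
        if (client.init(clientArgs)) {
            client.run();                        // returns right after submission
        }
    }
}
```

Note that run() returns after submitting the application rather than waiting for completion. Returning to the field declarations, the flag below backs the keep-containers behavior described in the preceding comment.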
+  private boolean keepContainers = false;
+
+  private long attemptFailuresValidityInterval = -1;
+
+  // Debug flag
+  private boolean debugFlag = false;
+
+  // Timeline domain ID
+  private String domainId = null;
+
+  // Flag to indicate whether to create the domain of the given ID
+  private boolean toCreateDomain = false;
+
+  // Timeline domain reader access control
+  private String viewACLs = null;
+
+  // Timeline domain writer access control
+  private String modifyACLs = null;
+
+  // Command line options
+  private Options opts;
+
+  private boolean tensorboardFlag = false;
+
+  private static final String shellCommandPath = "shellCommands";
+  private static final String shellArgsPath = "shellArgs";
+  private static final String appMasterJarPath = "AppMaster.jar";
+  // Hardcoded path to custom log4j properties
+  private static final String log4jPath = "log4j.properties";
+
+  public static final String SCRIPT_PATH = "ExecScript";
+
+  /**
+   * @param args Command line arguments
+   */
+  public static void main(String[] args) {
+    boolean result = false;
+    try {
+      Client client = new Client();
+      LOG.info("Initializing Client");
+      try {
+        boolean doRun = client.init(args);
+        if (!doRun) {
+          System.exit(0);
+        }
+      } catch (IllegalArgumentException e) {
+        System.err.println(e.getLocalizedMessage());
+        client.printUsage();
+        System.exit(-1);
+      }
+      result = client.run();
+    } catch (Throwable t) {
+      LOG.fatal("Error running Client", t);
+      System.exit(1);
+    }
+    if (result) {
+      LOG.info("Application submitted successfully");
+      System.exit(0);
+    }
+    LOG.error("Application failed to submit");
+    System.exit(2);
+  }
+
+  /**
+   * Create a client using the given configuration.
+   */
+  public Client(Configuration conf) throws Exception {
+    this(
+        ApplicationMaster.class.getName(),
+        conf);
+  }
+
+  Client(String appMasterMainClass, Configuration conf) {
+    this.conf = conf;
+    this.appMasterMainClass = appMasterMainClass;
+    yarnClient = YarnClient.createYarnClient();
+    yarnClient.init(conf);
+    opts = new Options();
+    opts.addOption("appname", true, "Application Name. Default: " + TFConstants.DEFAULT_APPNAME);
+    opts.addOption("priority", true, "Application Priority. Default 0");
+    opts.addOption("queue", true, "RM Queue in which this application is to be submitted");
+    opts.addOption("master_memory", true, "Amount of memory in MB to be requested to run the application master");
+    opts.addOption("master_vcores", true, "Amount of virtual cores to be requested to run the application master");
+    opts.addOption("jar", true, "Jar file containing the application master");
+    opts.addOption("task_script", true, "Location of the task script; it will be copied to the task execution environment.");
+    opts.addOption("task_cmd", true, "Task execute command. At least one of 'task_script' or 'task_cmd' must be defined.");
+    opts.addOption("task_args", true, "Command line args for the task program."
+        + " Multiple args can be separated by spaces.");
+    opts.getOption("task_args").setArgs(Option.UNLIMITED_VALUES);
+    opts.addOption("input_path", true, "Input path for task program");
+    opts.addOption("output_path", true, "Output path for task program");
+    opts.addOption("docker_image", true, "Docker image for running the tasks");
+    opts.addOption("shell_env", true, "Environment for task program. Specified as env_key=env_val pairs");
+    opts.addOption("container_memory", true, "Amount of memory in MB to be requested to run the task program");
+    opts.addOption("container_vcores", true, "Amount of virtual cores to be requested to run the task program");
+    opts.addOption("num_containers", true, "Format <job_name>:<num_tasks>, separated by comma (e.g. ps:2,worker:4)");
+    opts.addOption("log_properties", true, "log4j.properties file");
+    opts.addOption("enable_tensorboard", false, "Start TensorBoard as part of job");
+    opts.addOption("debug", false, "Dump out debug information");
+    opts.addOption("domain", true, "ID of the timeline domain where the "
+        + "timeline entities will be put");
+    opts.addOption("view_acls", true, "Users and groups that are allowed to "
+        + "view the timeline entities in the given domain");
+    opts.addOption("modify_acls", true, "Users and groups that are allowed to "
+        + "modify the timeline entities in the given domain");
+    opts.addOption("create", false, "Flag to indicate whether to create the "
+        + "domain specified with -domain.");
+    opts.addOption("help", false, "Print usage");
+
+    opts.addOption("node_label_expression", true,
+        "Node label expression to determine the nodes where all the"
+            + " containers of this application will be allocated."
+            + " \"\" means containers can be allocated anywhere; if the option"
+            + " is not specified, the queue's default node_label_expression is used.");
+  }
+
+  /**
+   * Create a client using a default YarnConfiguration.
+   */
+  public Client() throws Exception {
+    this(new YarnConfiguration());
+  }
+
+  /**
+   * Helper function to print out usage
+   */
+  private void printUsage() {
+    new HelpFormatter().printHelp("Client", opts);
+  }
+
+  /**
+   * Parse command line options
+   * @param args Command line arguments
+   * @return Whether init was successful and the client should run
+   * @throws ParseException
+   */
+  public boolean init(String[] args) throws ParseException {
+
+    CommandLine cliParser = new GnuParser().parse(opts, args);
+
+    if (args.length == 0) {
+      throw new IllegalArgumentException("No args specified for client to initialize");
+    }
+
+    if (cliParser.hasOption("log_properties")) {
+      String log4jPath = cliParser.getOptionValue("log_properties");
+      try {
+        Log4jPropertyHelper.updateLog4jConfiguration(Client.class, log4jPath);
+      } catch (Exception e) {
+        LOG.warn("Cannot set up custom log4j properties. " + e);
+      }
+    }
+
+    if (cliParser.hasOption("help")) {
+      printUsage();
+      return false;
+    }
+
+    if (cliParser.hasOption("debug")) {
+      debugFlag = true;
+    }
+
+    if (cliParser.hasOption("enable_tensorboard")) {
+      tensorboardFlag = true;
+    }
+
+    if (cliParser.hasOption("keep_containers_across_application_attempts")) {
+      LOG.info("keep_containers_across_application_attempts");
+      keepContainers = true;
+    }
+
+    appName = cliParser.getOptionValue("appname", TFConstants.DEFAULT_APPNAME);
+    amPriority = Integer.parseInt(cliParser.getOptionValue("priority", "0"));
+    amQueue = cliParser.getOptionValue("queue", "default");
+    amMemory = Integer.parseInt(cliParser.getOptionValue("master_memory", "10"));
+    amVCores = Integer.parseInt(cliParser.getOptionValue("master_vcores", "1"));
+
+    if (amMemory < 0) {
+      throw new IllegalArgumentException("Invalid memory specified for application master, exiting."
+          + " Specified memory=" + amMemory);
+    }
+    if (amVCores < 0) {
+      throw new IllegalArgumentException("Invalid virtual cores specified for application master, exiting."
+ + " Specified virtual cores=" + amVCores); + } + + if (!cliParser.hasOption("jar")) { + throw new IllegalArgumentException("No jar file specified for application master"); + } + + appMasterJar = cliParser.getOptionValue("jar"); + + taskScript = cliParser.getOptionValue("task_script", ""); + taskCmd = cliParser.getOptionValue("task_cmd", ""); + + if (taskScript.isEmpty() && taskCmd.isEmpty()) { + throw new IllegalArgumentException( + "No task script nor task command specified to be executed by application master"); + } + + if (cliParser.hasOption("task_args")) { + taskArgs = cliParser.getOptionValues("task_args"); + } + + inputPath = cliParser.getOptionValue("input_path", ""); + outputPath = cliParser.getOptionValue("output_path", ""); + dockerImage = cliParser.getOptionValue("docker_image", ""); + + if (cliParser.hasOption("shell_env")) { + String envs[] = cliParser.getOptionValues("shell_env"); + for (String env : envs) { + env = env.trim(); + int index = env.indexOf('='); + if (index == -1) { + shellEnv.put(env, ""); + continue; + } + String key = env.substring(0, index); + String val = ""; + if (index < (env.length()-1)) { + val = env.substring(index+1); + } + shellEnv.put(key, val); + } + } + + containerMemory = Integer.parseInt(cliParser.getOptionValue("container_memory", "10")); + containerVirtualCores = Integer.parseInt(cliParser.getOptionValue("container_vcores", "1")); + numContainers = cliParser.getOptionValue("num_containers", ""); + + if (containerMemory < 0 || containerVirtualCores < 0 || numContainers.isEmpty()) { + throw new IllegalArgumentException("Invalid no. of containers or container memory/vcores specified," + + " exiting." + + " Specified containerMemory=" + containerMemory + + ", containerVirtualCores=" + containerVirtualCores + + ", numContainer=" + numContainers); + } + + nodeLabelExpression = cliParser.getOptionValue("node_label_expression", null); + + log4jPropFile = cliParser.getOptionValue("log_properties", ""); + + // Get timeline domain options + if (cliParser.hasOption("domain")) { + domainId = cliParser.getOptionValue("domain"); + toCreateDomain = cliParser.hasOption("create"); + if (cliParser.hasOption("view_acls")) { + viewACLs = cliParser.getOptionValue("view_acls"); + } + if (cliParser.hasOption("modify_acls")) { + modifyACLs = cliParser.getOptionValue("modify_acls"); + } + } + + return true; + } + + /** + * Main run function for the client + * @return true if application completed successfully + * @throws IOException + * @throws YarnException + */ + public boolean run() throws IOException, YarnException { + + LOG.info("Running Client"); + yarnClient.start(); + + YarnClusterMetrics clusterMetrics = yarnClient.getYarnClusterMetrics(); + LOG.info("Got Cluster metric info from ASM" + + ", numNodeManagers=" + clusterMetrics.getNumNodeManagers()); + + List clusterNodeReports = yarnClient.getNodeReports( + NodeState.RUNNING); + LOG.info("Got Cluster node info from ASM"); + for (NodeReport node : clusterNodeReports) { + LOG.info("Got node report from ASM for" + + ", nodeId=" + node.getNodeId() + + ", nodeAddress" + node.getHttpAddress() + + ", nodeRackName" + node.getRackName() + + ", nodeNumContainers" + node.getNumContainers()); + } + + QueueInfo queueInfo = yarnClient.getQueueInfo(this.amQueue); + LOG.info("Queue info" + + ", queueName=" + queueInfo.getQueueName() + + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity() + + ", queueMaxCapacity=" + queueInfo.getMaximumCapacity() + + ", queueApplicationCount=" + queueInfo.getApplications().size() + + 
", queueChildQueueCount=" + queueInfo.getChildQueues().size()); + + List listAclInfo = yarnClient.getQueueAclsInfo(); + for (QueueUserACLInfo aclInfo : listAclInfo) { + for (QueueACL userAcl : aclInfo.getUserAcls()) { + LOG.info("User ACL Info for Queue" + + ", queueName=" + aclInfo.getQueueName() + + ", userAcl=" + userAcl.name()); + } + } + + if (domainId != null && domainId.length() > 0 && toCreateDomain) { + prepareTimelineDomain(); + } + + // Get a new application id + YarnClientApplication app = yarnClient.createApplication(); + GetNewApplicationResponse appResponse = app.getNewApplicationResponse(); + + int maxMem = appResponse.getMaximumResourceCapability().getMemory(); + LOG.info("Max mem capabililty of resources in this cluster " + maxMem); + + // A resource ask cannot exceed the max. + if (amMemory > maxMem) { + LOG.info("AM memory specified above max threshold of cluster. Using max value." + + ", specified=" + amMemory + + ", max=" + maxMem); + amMemory = maxMem; + } + + int maxVCores = appResponse.getMaximumResourceCapability().getVirtualCores(); + LOG.info("Max virtual cores capabililty of resources in this cluster " + maxVCores); + + if (amVCores > maxVCores) { + LOG.info("AM virtual cores specified above max threshold of cluster. " + + "Using max value." + ", specified=" + amVCores + + ", max=" + maxVCores); + amVCores = maxVCores; + } + + // set the application name + ApplicationSubmissionContext appContext = app.getApplicationSubmissionContext(); + ApplicationId appId = appContext.getApplicationId(); + + appContext.setKeepContainersAcrossApplicationAttempts(keepContainers); + appContext.setApplicationName(appName); + + if (attemptFailuresValidityInterval >= 0) { + appContext + .setAttemptFailuresValidityInterval(attemptFailuresValidityInterval); + } + + // set local resources for the application master + // local files or archives as needed + // In this scenario, the jar file for the application master is part of the local resources + Map localResources = new HashMap<>(); + + LOG.info("Copy App Master jar from local filesystem and add to local environment"); + // Copy the application master jar to the filesystem + // Create a local resource to point to the destination jar path + FileSystem fs = FileSystem.get(conf); + addToLocalResources(fs, appMasterJar, appMasterJarPath, appId.toString(), + localResources, null); + + // Set the log4j properties if needed + if (!log4jPropFile.isEmpty()) { + addToLocalResources(fs, log4jPropFile, log4jPath, appId.toString(), + localResources, null); + } + + // The shell script has to be made available on the final container(s) + // where it will be executed. + // To do this, we need to first copy into the filesystem that is visible + // to the yarn framework. + // We do not need to set this as a local resource for the application + // master as the application master does not need it. 
+ String hdfsShellScriptLocation = ""; + long hdfsShellScriptLen = 0; + long hdfsShellScriptTimestamp = 0; + if (!taskScript.isEmpty()) { + Path shellSrc = new Path(taskScript); + String shellPathSuffix = + appName + "/" + appId.toString() + "/" + SCRIPT_PATH; + Path shellDst = + new Path(fs.getHomeDirectory(), shellPathSuffix); + fs.copyFromLocalFile(false, true, shellSrc, shellDst); + hdfsShellScriptLocation = shellDst.toUri().toString(); + FileStatus shellFileStatus = fs.getFileStatus(shellDst); + hdfsShellScriptLen = shellFileStatus.getLen(); + hdfsShellScriptTimestamp = shellFileStatus.getModificationTime(); + } + + if (!taskCmd.isEmpty()) { + addToLocalResources(fs, null, shellCommandPath, appId.toString(), + localResources, taskCmd); + } + + if (taskArgs.length > 0) { + addToLocalResources(fs, null, shellArgsPath, appId.toString(), + localResources, StringUtils.join(taskArgs, " ")); + } + + // Set the necessary security tokens as needed + //amContainer.setContainerTokens(containerToken); + + // Set the env variables to be setup in the env where the application master will be run + LOG.info("Set the environment for the application master"); + Map env = new HashMap(); + + // put location of shell script into env + // using the env info, the application master will create the correct local resource for the + // eventual containers that will be launched to execute the shell scripts + env.put(TFConstants.SCRIPTLOCATION, hdfsShellScriptLocation); + env.put(TFConstants.SCRIPTTIMESTAMP, Long.toString(hdfsShellScriptTimestamp)); + env.put(TFConstants.SHELLSCRIPTLEN, Long.toString(hdfsShellScriptLen)); + if (domainId != null && domainId.length() > 0) { + env.put(TFConstants.TIMELINEDOMAIN, domainId); + } + + // Add AppMaster.jar location to classpath + // At some point we should not be required to add + // the hadoop specific classpaths to the env. + // It should be provided out of the box. + // For now setting all required classpaths including + // the classpath to "." 
for the application jar
+    StringBuilder classPathEnv = new StringBuilder(Environment.CLASSPATH.$$())
+        .append(ApplicationConstants.CLASS_PATH_SEPARATOR).append("./*");
+    for (String c : conf.getStrings(
+        YarnConfiguration.YARN_APPLICATION_CLASSPATH,
+        YarnConfiguration.DEFAULT_YARN_CROSS_PLATFORM_APPLICATION_CLASSPATH)) {
+      classPathEnv.append(ApplicationConstants.CLASS_PATH_SEPARATOR);
+      classPathEnv.append(c.trim());
+    }
+    classPathEnv.append(ApplicationConstants.CLASS_PATH_SEPARATOR).append(
+        "./log4j.properties");
+
+    // add the runtime classpath needed for tests to work
+    if (conf.getBoolean(YarnConfiguration.IS_MINI_YARN_CLUSTER, false)) {
+      classPathEnv.append(':');
+      classPathEnv.append(System.getProperty("java.class.path"));
+    }
+
+    env.put("CLASSPATH", classPathEnv.toString());
+
+    // Set the necessary command to execute the application master
+    Vector<CharSequence> vargs = new Vector<>(30);
+
+    // Set java executable command
+    LOG.info("Setting up app master command");
+    vargs.add(Environment.JAVA_HOME.$$() + "/bin/java");
+    // Set Xmx based on am memory size
+    vargs.add("-Xmx" + amMemory + "m");
+    // Set class name
+    vargs.add(appMasterMainClass);
+    // Set params for Application Master
+    vargs.add("--container_memory " + String.valueOf(containerMemory));
+    vargs.add("--container_vcores " + String.valueOf(containerVirtualCores));
+    vargs.add("--num_containers " + numContainers);
+    vargs.add("--appname " + appName);
+    if (!inputPath.isEmpty()) {
+      vargs.add("--input_path " + inputPath);
+    }
+    if (!outputPath.isEmpty()) {
+      vargs.add("--output_path " + outputPath);
+    }
+    if (!dockerImage.isEmpty()) {
+      vargs.add("--docker_image " + dockerImage);
+    }
+    if (null != nodeLabelExpression) {
+      appContext.setNodeLabelExpression(nodeLabelExpression);
+    }
+    vargs.add("--priority " + String.valueOf(shellCmdPriority));
+
+    for (Map.Entry<String, String> entry : shellEnv.entrySet()) {
+      vargs.add("--shell_env " + entry.getKey() + "=" + entry.getValue());
+    }
+    if (debugFlag) {
+      vargs.add("--debug");
+    }
+    if (tensorboardFlag) {
+      vargs.add("--enable_tensorboard");
+    }
+
+    vargs.add("1>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/AppMaster.stdout");
+    vargs.add("2>" + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/AppMaster.stderr");
+
+    // Get final command
+    StringBuilder command = new StringBuilder();
+    for (CharSequence str : vargs) {
+      command.append(str).append(" ");
+    }
+
+    LOG.info("Completed setting up app master command " + command.toString());
+    List<String> commands = new ArrayList<>();
+    commands.add(command.toString());
+
+    // Set up the container launch context for the application master
+    ContainerLaunchContext amContainer = ContainerLaunchContext.newInstance(
+        localResources, env, commands, null, null, null);
+
+    // Set up resource type requirements
+    // For now, both memory and vcores are supported, so we set memory and
+    // vcores requirements
+    Resource capability = Resource.newInstance(amMemory, amVCores);
+    appContext.setResource(capability);
+
+    // Service data is a binary blob that can be passed to the application
+    // Not needed in this scenario
+    // amContainer.setServiceData(serviceData);
+
+    // Setup security tokens
+    if (UserGroupInformation.isSecurityEnabled()) {
+      // Note: the Credentials class is marked as LimitedPrivate for HDFS and MapReduce
+      Credentials credentials = new Credentials();
+      String tokenRenewer = conf.get(YarnConfiguration.RM_PRINCIPAL);
+      if (tokenRenewer == null || tokenRenewer.length() == 0) {
+        throw new IOException(
+            "Can't get Master Kerberos principal for the RM to use
as renewer"); + } + + // For now, only getting tokens for the default file-system. + final Token tokens[] = + fs.addDelegationTokens(tokenRenewer, credentials); + if (tokens != null) { + for (Token token : tokens) { + LOG.info("Got dt for " + fs.getUri() + "; " + token); + } + } + DataOutputBuffer dob = new DataOutputBuffer(); + credentials.writeTokenStorageToStream(dob); + ByteBuffer fsTokens = ByteBuffer.wrap(dob.getData(), 0, dob.getLength()); + amContainer.setTokens(fsTokens); + } + + appContext.setAMContainerSpec(amContainer); + + // Set the priority for the application master + Priority pri = Priority.newInstance(amPriority); + appContext.setPriority(pri); + + // Set the queue to which this application is to be submitted in the RM + appContext.setQueue(amQueue); + + // Submit the application to the applications manager + // SubmitApplicationResponse submitResp = applicationsManager.submitApplication(appRequest); + // Ignore the response as either a valid response object is returned on success + // or an exception thrown to denote some form of a failure + LOG.info("Submitting application to ASM"); + + yarnClient.submitApplication(appContext); + + return true; + + } + + private void addToLocalResources(FileSystem fs, String fileSrcPath, + String fileDstPath, String appId, Map localResources, + String resources) throws IOException { + String suffix = + appName + "/" + appId + "/" + fileDstPath; + Path dst = + new Path(fs.getHomeDirectory(), suffix); + if (fileSrcPath == null) { + FSDataOutputStream ostream = null; + try { + ostream = FileSystem + .create(fs, dst, new FsPermission((short) 0710)); + ostream.writeUTF(resources); + } finally { + IOUtils.closeQuietly(ostream); + } + } else { + fs.copyFromLocalFile(new Path(fileSrcPath), dst); + } + FileStatus scFileStatus = fs.getFileStatus(dst); + LocalResource scRsrc = + LocalResource.newInstance( + ConverterUtils.getYarnUrlFromURI(dst.toUri()), + LocalResourceType.FILE, LocalResourceVisibility.APPLICATION, + scFileStatus.getLen(), scFileStatus.getModificationTime()); + localResources.put(fileDstPath, scRsrc); + } + + private void prepareTimelineDomain() { + TimelineClient timelineClient = null; + if (conf.getBoolean(YarnConfiguration.TIMELINE_SERVICE_ENABLED, + YarnConfiguration.DEFAULT_TIMELINE_SERVICE_ENABLED)) { + timelineClient = TimelineClient.createTimelineClient(); + timelineClient.init(conf); + timelineClient.start(); + } else { + LOG.warn("Cannot put the domain " + domainId + + " because the timeline service is not enabled"); + return; + } + try { + TimelineDomain domain = new TimelineDomain(); + domain.setId(domainId); + domain.setReaders( + viewACLs != null && viewACLs.length() > 0 ? viewACLs : " "); + domain.setWriters( + modifyACLs != null && modifyACLs.length() > 0 ? modifyACLs : " "); + timelineClient.putDomain(domain); + LOG.info("Put the timeline domain: " + + TimelineUtils.dumpTimelineRecordtoJSON(domain)); + } catch (Exception e) { + LOG.error("Error when putting the timeline domain", e); + } finally { + timelineClient.stop(); + } + } +} diff --git a/yarn/src/main/java/org/tensorflow/hadoop/yarn/Log4jPropertyHelper.java b/yarn/src/main/java/org/tensorflow/hadoop/yarn/Log4jPropertyHelper.java new file mode 100644 index 00000000..149f4a6a --- /dev/null +++ b/yarn/src/main/java/org/tensorflow/hadoop/yarn/Log4jPropertyHelper.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.hadoop.yarn; + +/** + * TensorFlow launcher for YARN + * + * Modified from YARN sample application: DistributedShell. + */ + +import java.io.FileInputStream; +import java.io.InputStream; +import java.util.Map.Entry; +import java.util.Properties; + +import org.apache.commons.io.IOUtils; +import org.apache.log4j.LogManager; +import org.apache.log4j.PropertyConfigurator; + +public class Log4jPropertyHelper { + + public static void updateLog4jConfiguration(Class targetClass, + String log4jPath) throws Exception { + Properties customProperties = new Properties(); + FileInputStream fs = null; + InputStream is = null; + try { + fs = new FileInputStream(log4jPath); + is = targetClass.getResourceAsStream("/log4j.properties"); + customProperties.load(fs); + Properties originalProperties = new Properties(); + originalProperties.load(is); + for (Entry entry : customProperties.entrySet()) { + originalProperties.setProperty(entry.getKey().toString(), entry + .getValue().toString()); + } + LogManager.resetConfiguration(); + PropertyConfigurator.configure(originalProperties); + }finally { + IOUtils.closeQuietly(is); + IOUtils.closeQuietly(fs); + } + } +} diff --git a/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFConstants.java b/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFConstants.java new file mode 100644 index 00000000..0d421423 --- /dev/null +++ b/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFConstants.java @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.hadoop.yarn; + +/** + * TensorFlow launcher for YARN + * + * Modified from YARN sample application: DistributedShell. 
+ */ +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +/** + * Constants used in both Client and ApplicationMaster + */ +@InterfaceAudience.Public +@InterfaceStability.Unstable +public class TFConstants { + + /** + * Environment key name pointing to the shell script's location + */ + public static final String SCRIPTLOCATION = "SCRIPTLOCATION"; + + /** + * Environment key name denoting the file timestamp for the shell script. + * Used to validate the local resource. + */ + public static final String SCRIPTTIMESTAMP = "SCRIPTTIMESTAMP"; + + /** + * Environment key name denoting the file content length for the shell script. + * Used to validate the local resource. + */ + public static final String SHELLSCRIPTLEN = "SHELLSCRIPTLEN"; + + /** + * Environment key name denoting the timeline domain ID. + */ + public static final String TIMELINEDOMAIN = "TIMELINEDOMAIN"; + + /** + * Default application name + */ + public static final String DEFAULT_APPNAME = "TensorFlow"; +} diff --git a/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFContainerRequest.java b/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFContainerRequest.java new file mode 100644 index 00000000..9cf17b1e --- /dev/null +++ b/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFContainerRequest.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.hadoop.yarn; + +/** + * TensorFlow launcher for YARN + * + * Modified from YARN sample application: DistributedShell. + */ + +public class TFContainerRequest { + private int virtualCores; + private int memory; + private int priority; + + private TFContainerRequest() {} + + public TFContainerRequest(int virtualCores, int memory, int priority) + { + this.virtualCores = virtualCores; + this.memory = memory; + this.priority = priority; + } + + public TFContainerRequest(TFContainerRequest that) { + this.virtualCores = that.virtualCores; + this.memory = that.memory; + this.priority = that.priority; + } + + public int getVirtualCores() { + return virtualCores; + } + + public int getMemory() { + return memory; + } + + public int getPriority() { + return priority; + } +} diff --git a/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFServlet.java b/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFServlet.java new file mode 100644 index 00000000..a2addc6a --- /dev/null +++ b/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFServlet.java @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.hadoop.yarn; + +/** + * TensorFlow launcher for YARN + * + * Modified from YARN sample application: DistributedShell. + */ + +import java.io.IOException; +import java.io.PrintWriter; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +public class TFServlet extends HttpServlet +{ + private static final long serialVersionUID = 7965676366699736489L; + + @Override + public void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException + { + ApplicationMaster applicationMaster = (ApplicationMaster) getServletContext().getAttribute(ApplicationMaster.class.getName()); + PrintWriter out = response.getWriter(); + response.setContentType("text/html"); + applicationMaster.printHtmlStatus(out); + out.close(); + } +} diff --git a/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFSession.java b/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFSession.java new file mode 100644 index 00000000..ac8d5d22 --- /dev/null +++ b/yarn/src/main/java/org/tensorflow/hadoop/yarn/TFSession.java @@ -0,0 +1,708 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tensorflow.hadoop.yarn; + +/** + * TensorFlow launcher for YARN + * + * Modified from YARN sample application: DistributedShell. 
+ */
+
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.FSDataOutputStream;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.fs.permission.FsPermission;
+import org.apache.hadoop.yarn.api.records.*;
+import org.apache.hadoop.yarn.util.ConverterUtils;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.PrintWriter;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+class TFSession {
+  static final String REGISTRY_SERVICE_CLASS = "yarn-tensorflow";
+  private static final Log LOG = LogFactory.getLog(TFSession.class);
+
+  private static final String linux_bash_command = "bash";
+  private static final String TaskProgramFile = "taskProgram";
+  private static final String TaskScriptFile = "taskScript";
+  private static final String TaskStarterPath = "taskStarter";
+  private static final String TaskStarterModule = "wrapper";
+  private static final String TensorBoardProgramFile = "tensorboardProgram";
+  private static final String TaskStarterResource = "task-starter.zip";
+
+  private static final String DTF_TENSORBOARD_JOBNAME = "TensorBoard";
+  private static final String DTF_TASK_PROGRAM = "DTF_TASK_PROGRAM";
+  private static final String DTF_TASK_SCRIPT = "DTF_TASK_SCRIPT";
+  private static final String DTF_TASK_JOB_NAME = "DTF_TASK_JOB_NAME";
+  private static final String DTF_TASK_INDEX = "DTF_TASK_INDEX";
+  private static final String DTF_INPUT_PATH = "DTF_INPUT_PATH";
+  private static final String DTF_OUTPUT_PATH = "DTF_OUTPUT_PATH";
+  private static final String DTF_DOCKER_IMAGE = "DTF_DOCKER_IMAGE";
+  private static final String DTF_APPLICATION_ID = "DTF_APPLICATION_ID";
+  private static final String DTF_ZK_HOSTS = "DTF_ZK_HOSTS";
+  private static final String DTF_SERVICE_CLASS = "DTF_SERVICE_CLASS";
+  private static final String DTF_JOBNAME_HOSTS_FORMAT = "DTF_%s_HOSTS";
+
+  private static final String DEF_PS_JOB_NAME = "ps";
+  private static final String DEF_CHIEF_JOB_NAME = "worker";
+  private static final int DEF_CHIEF_INDEX = -1; // no chief defined by default
+
+  public enum TaskType {
+    TASK_TYPE_CHIEF, TASK_TYPE_PARAMETER_SERVER, TASK_TYPE_OTHERS
+  }
+
+  // Application name
+  private String appName = "";
+  // Application ID
+  private String appIDString;
+  // Shell command to be executed
+  private String taskCommand = "";
+  // Args to be passed to the shell command
+  private String taskArgs = "";
+  // Shell script path in the FS (its location is obtained from the env)
+  private String scriptPath = "";
+  private String input_path = "";
+  private String output_path = "";
+  private String dockerImage = "";
+  private String registryQuorum = "";
+  private boolean enableTensorBoard = false;
+  private TFContainerRequest defaultContainerRequest = new TFContainerRequest(1, 1024, 0);
+
+  // Default job names: chief = "worker[0]", ps = "ps[*]"
+  private String psJobName = DEF_PS_JOB_NAME;
+  private String chiefJobName = DEF_CHIEF_JOB_NAME;
+  private int chiefIndex = DEF_CHIEF_INDEX;
+
+  // map of job name => tasks
+  private Map<String, Task[]> jobTasks = new ConcurrentHashMap<>();
+
+  private boolean isDone = false;
+  private FinalApplicationStatus finalStatus = FinalApplicationStatus.UNDEFINED;
+  private String finalMessage = null;
+
+  private TFSession(final TFSessionBuilder builder) {
+    appName = builder.appName;
+    appIDString = builder.appIDString;
+    taskCommand = builder.taskCmd;
+    taskArgs = builder.taskArgs;
+    scriptPath = builder.scriptPath;
+    input_path = builder.inputPath;
+    output_path = builder.outputPath;
+    dockerImage = builder.dockerImage;
+    registryQuorum = builder.registryQuorum;
+    enableTensorBoard = builder.enableTensorBoard;
+    defaultContainerRequest = new TFContainerRequest(builder.defaultContainerRequest);
+
+    psJobName = DEF_PS_JOB_NAME;
+    chiefJobName = DEF_CHIEF_JOB_NAME;
+    chiefIndex = DEF_CHIEF_INDEX;
+
+    for (Map.Entry<String, Integer> entry : parseClusterRequirementString(builder.clusterReqString).entrySet()) {
+      String job_name = entry.getKey();
+      Integer nTasks = entry.getValue();
+
+      // set up one task slot per requested task, to keep track of which
+      // tasks still need a container
+      Task[] tasks = new Task[nTasks];
+      jobTasks.put(job_name, tasks);
+    }
+    if (enableTensorBoard) {
+      // set up an additional task for TensorBoard
+      jobTasks.put(DTF_TENSORBOARD_JOBNAME, new Task[1]);
+    }
+  }
+
+  private LocalResource makeLocalResource(FileSystem fs, Path dst, LocalResourceType type) throws IOException {
+    FileStatus scFileStatus = fs.getFileStatus(dst);
+    return LocalResource.newInstance(
+        ConverterUtils.getYarnUrlFromURI(dst.toUri()), type, LocalResourceVisibility.APPLICATION,
+        scFileStatus.getLen(), scFileStatus.getModificationTime());
+  }
+
+  synchronized ArrayList<TFContainerRequest> getRequiredContainers() {
+    ArrayList<TFContainerRequest> requests = new ArrayList<>();
+
+    for (Map.Entry<String, Task[]> entry : jobTasks.entrySet()) {
+      Task[] tasks = entry.getValue();
+      for (Task task : tasks) {
+        if (task == null) {
+          // currently, all tasks are assumed to have the same container request;
+          // job- and/or task-specific requests are a potential enhancement
+          TFContainerRequest request = new TFContainerRequest(defaultContainerRequest);
+          requests.add(request);
+        }
+      }
+    }
+
+    return requests;
+  }
+
+  private synchronized boolean checkAllReady() {
+    boolean ready = true;
+
+    for (Map.Entry<String, Task[]> entry : jobTasks.entrySet()) {
+      Task[] tasks = entry.getValue();
+      for (Task task : tasks) {
+        if (task == null || !task.isReady()) {
+          ready = false;
+          break;
+        }
+      }
+    }
+
+    if (ready) {
+      notifyAll();
+    }
+
+    return ready;
+  }
+
+  private Task getTaskByContainerId(ContainerId id) {
+    for (Map.Entry<String, Task[]> entry : jobTasks.entrySet()) {
+      Task[] tasks = entry.getValue();
+      for (Task task : tasks) {
+        if (task == null) {
+          // slot not allocated to a container yet
+          continue;
+        }
+        ContainerId containerId = task.getContainerId();
+        if (containerId != null && containerId.equals(id)) {
+          return task;
+        }
+      }
+    }
+
+    return null;
+  }
+
+  String getTaskStarterCommand() {
+    // executable command for the Python task-starter wrapper
+    return String.format("( PYTHONPATH=%s:$PYTHONPATH python -m %s.__main__ --debug )", TaskStarterPath, TaskStarterModule);
+  }
+
+  String getJobAndIndex(ContainerId containerId) {
+    for (Map.Entry<String, Task[]> entry : jobTasks.entrySet()) {
+      Task[] tasks = entry.getValue();
+      for (Task task : tasks) {
+        if (task != null && task.getContainerId().equals(containerId)) {
+          return String.format("%s[%d]", task.getJobName(), task.getIndex());
+        }
+      }
+    }
+
+    return "unknown[#]";
+  }
+
+  Map<String, String> getClusterSpec() {
+    Map<String, String> map = new HashMap<>();
+
+    for (Map.Entry<String, Task[]> entry : jobTasks.entrySet()) {
+      String jobName = entry.getKey();
+      Task[] tasks = entry.getValue();
+
+      boolean first = true;
+      StringBuilder builder = new StringBuilder();
+      for (Task task : tasks) {
+        if (task == null) {
+          continue;
+        }
+
+        String hostPort = task.getHostPort();
+        if (!first)
+          builder.append(",");
+        first = false;
+        builder.append(hostPort);
+      }
+      String jobNameEnv = convertJobNameEnv(jobName);
+      map.put(jobNameEnv, builder.toString());
+    }
+
+    return map;
+  }
+
+  synchronized boolean updateAllocatedPort(String jobName, int index, int port) {
+    Task[] tasks = jobTasks.get(jobName);
+
+    if (tasks != null && index >= 0 && index < tasks.length && tasks[index] != null)
+      tasks[index].setPort(port);
+    return checkAllReady();
+  }
+
+  void setAppGlobalEnv(Map<String, String> shellEnv) {
+    shellEnv.put(DTF_SERVICE_CLASS, REGISTRY_SERVICE_CLASS);
+    shellEnv.put(DTF_APPLICATION_ID, appIDString);
+    shellEnv.put(DTF_ZK_HOSTS, registryQuorum);
+
+    if (!input_path.isEmpty()) {
+      shellEnv.put(DTF_INPUT_PATH, input_path);
+    }
+    if (!output_path.isEmpty()) {
+      shellEnv.put(DTF_OUTPUT_PATH, output_path);
+    }
+    if (!dockerImage.isEmpty()) {
+      shellEnv.put(DTF_DOCKER_IMAGE, dockerImage);
+    }
+    if (!scriptPath.isEmpty()) {
+      shellEnv.put(DTF_TASK_SCRIPT, TaskScriptFile);
+    }
+    shellEnv.put(DTF_TASK_PROGRAM, TaskProgramFile);
+  }
+
+  private String getFsBaseDir() {
+    return appName + "/" + appIDString;
+  }
+
+  private void writeBytesToFs(FileSystem fs, Path dst, String content) throws IOException {
+    FSDataOutputStream ostream = null;
+    try {
+      ostream = FileSystem.create(fs, dst, new FsPermission((short) 0710));
+      ostream.writeBytes(content);
+    } finally {
+      IOUtils.closeQuietly(ostream);
+    }
+  }
+
+  private void writeBytesToFs(FileSystem fs, Path dst, InputStream content) throws IOException {
+    FSDataOutputStream ostream = null;
+    try {
+      ostream = FileSystem.create(fs, dst, new FsPermission((short) 0710));
+      IOUtils.copy(content, ostream);
+    } finally {
+      IOUtils.closeQuietly(ostream);
+    }
+  }
+
+  private Path writeTaskStarterToFs(FileSystem fs) throws IOException {
+    ClassLoader clsLoader = getClass().getClassLoader();
+    InputStream inStream = clsLoader.getResourceAsStream(TaskStarterResource);
+
+    String baseDir = getFsBaseDir() + "/" + TaskStarterResource;
+    Path dst = new Path(fs.getHomeDirectory(), baseDir);
+
+    writeBytesToFs(fs, dst, inStream);
+
+    return dst;
+  }
+
+  private Path writeTensorBoardProgramToFs(FileSystem fs) throws IOException {
+    String taskProgramText;
+    String hostsEnvVar = convertJobNameEnv(DTF_TENSORBOARD_JOBNAME);
+    taskProgramText = "PORT=$(echo ${" + hostsEnvVar + "} | sed 's#.*:##'); tensorboard --port ${PORT} --logdir ${DTF_OUTPUT_PATH}";
+
+    String baseDir = getFsBaseDir() + "/" + TensorBoardProgramFile;
+    Path dst = new Path(fs.getHomeDirectory(), baseDir);
+
+    writeBytesToFs(fs, dst, taskProgramText);
+
+    return dst;
+  }
+
+  void createResources(FileSystem fs, Map<String, LocalResource> localResources) throws IOException {
+    Path dst;
+    LocalResource scRsrc;
+
+    // add taskStarter to localResources
+    dst = writeTaskStarterToFs(fs);
+    scRsrc = makeLocalResource(fs, dst, LocalResourceType.ARCHIVE);
+    localResources.put(TaskStarterPath, scRsrc);
+
+    // add script to localResources
+    if (!scriptPath.isEmpty()) {
+      dst = new Path(scriptPath);
+      scRsrc = makeLocalResource(fs, dst, LocalResourceType.FILE);
+      localResources.put(TaskScriptFile, scRsrc);
+    }
+
+    dst = writeTaskProgramToFs(fs);
+    scRsrc = makeLocalResource(fs, dst, LocalResourceType.FILE);
+    localResources.put(TaskProgramFile, scRsrc);
+
+    if (enableTensorBoard) {
+      dst = writeTensorBoardProgramToFs(fs);
+      scRsrc = makeLocalResource(fs, dst, LocalResourceType.FILE);
+      localResources.put(TensorBoardProgramFile, scRsrc);
+    }
+  }
+
+  private Path writeTaskProgramToFs(FileSystem fs) throws IOException {
+    String taskProgramText;
+    if (!taskCommand.isEmpty()) {
+      // command args...
+      taskProgramText = taskCommand + " " + taskArgs;
+    } else if (!scriptPath.isEmpty()) {
+      // bash taskScript args...
+      taskProgramText = linux_bash_command + " " + TaskScriptFile + " " + taskArgs;
+    } else {
+      // error: both the task command and the task script are empty
+      throw new IllegalArgumentException("No task command or task script provided");
+    }
+
+    String baseDir = getFsBaseDir() + "/" + TaskProgramFile;
+    Path dst = new Path(fs.getHomeDirectory(), baseDir);
+
+    writeBytesToFs(fs, dst, taskProgramText);
+
+    return dst;
+  }
+
+  private static Map<String, Integer> parseClusterRequirementString(String clusterReqString) {
+    Map<String, Integer> map = new ConcurrentHashMap<>();
+
+    String[] jobs = clusterReqString.split(",");
+    for (String jobName : jobs) {
+      String[] job = jobName.split(":");
+      if (job.length != 2) {
+        throw new IllegalArgumentException(
+            "Invalid cluster requirement string <" + clusterReqString + ">");
+      }
+      map.put(job[0], Integer.valueOf(job[1]));
+    }
+    return map;
+  }
+
+  boolean isDone() {
+    if (isDone)
+      return true;
+
+    int failedCnt = 0;
+
+    for (Map.Entry<String, Task[]> entry : jobTasks.entrySet()) {
+      String jobName = entry.getKey();
+      Task[] tasks = entry.getValue();
+
+      if (jobName.equals(psJobName) || jobName.equals(DTF_TENSORBOARD_JOBNAME)) {
+        // ignore the PS and TensorBoard jobs; they do not terminate on their own
+        continue;
+      }
+
+      for (Task task : tasks) {
+        if (task == null) {
+          LOG.info("TF: task is not started yet, isDone=false");
+          return false;
+        }
+        boolean isCompleted = task.isCompleted();
+        if (!isCompleted) {
+          LOG.info("TF: task=" + task + ", is not completed yet, isDone=false");
+          return false;
+        }
+
+        int exitStatus = task.getExitStatus();
+        if (exitStatus != 0) {
+          failedCnt++;
+        }
+      }
+    }
+
+    isDone = true;
+    if (failedCnt > 0) {
+      setFinalStatus(FinalApplicationStatus.FAILED,
+          "At least one job task exited with non-zero status, failedCnt=" + failedCnt);
+    } else {
+      setFinalStatus(FinalApplicationStatus.SUCCEEDED, null);
+    }
+    return isDone;
+  }
+
+  private TaskType getTaskType(Task task) {
+    TaskType type;
+
+    int index = task.getIndex();
+    String jobName = task.getJobName();
+
+    if (index == chiefIndex && jobName.equals(chiefJobName))
+      type = TaskType.TASK_TYPE_CHIEF;
+    else if (jobName.equals(psJobName))
+      type = TaskType.TASK_TYPE_PARAMETER_SERVER;
+    else
+      type = TaskType.TASK_TYPE_OTHERS;
+
+    return type;
+  }
+
+  boolean handleContainerTaskCompleted(ContainerId containerId,
+      int exitStatus) {
+    Task task = getTaskByContainerId(containerId);
+    if (task == null) {
+      return false;
+    }
+
+    TaskType taskType = getTaskType(task);
+
+    LOG.info("TF: handleContainerTaskCompleted(): container=" + task.containerId
+        + ", exitStatus=" + exitStatus + ", taskType=" + taskType);
+
+    task.setExitStatus(exitStatus);
+
+    switch (taskType) {
+    case TASK_TYPE_CHIEF:
+    case TASK_TYPE_OTHERS:
+      if (exitStatus != 0) {
+        isDone = true;
+        setFinalStatus(FinalApplicationStatus.FAILED,
+            "Failed: a worker task exited with exitStatus=" + exitStatus + ", exiting application");
+      }
+      break;
+    case TASK_TYPE_PARAMETER_SERVER:
+      break;
+    default:
+      // not a TF task
+      break;
+    } // END of switch(taskType)
+
+    LOG.info("TF: container=" + task.containerId + ", isDone=" + isDone
+        + ", finalStatus=" + finalStatus
+        + ", finalMessage=" + (finalMessage != null ? finalMessage : "null"));
+
+    return isDone;
+  }
finalMessage : "null")); + + return isDone; + } + + + private void setFinalStatus(FinalApplicationStatus status, + String message) { + finalStatus = status; + finalMessage = message; + } + + String getFinalMessage() { + return finalMessage; + } + + FinalApplicationStatus getFinalStatus() { + return finalStatus; + } + + private String convertJobNameEnv(String jobName) { + return String.format(DTF_JOBNAME_HOSTS_FORMAT, jobName.toUpperCase()); + } + + synchronized boolean addContainer(Container container, + Map myShellEnv) { + String host = container.getNodeId().getHost(); + + // find a job+task to assign the container host + for (Map.Entry entry : jobTasks.entrySet()) { + String jobName = entry.getKey(); + Task[] tasks = entry.getValue(); + + for (int i = 0; i < tasks.length; i++) { + if (tasks[i] == null) { + tasks[i] = new Task(jobName, host, i, container.getId()); + myShellEnv.put(DTF_TASK_JOB_NAME, jobName); + myShellEnv.put(DTF_TASK_INDEX, String.valueOf(i)); + + if (jobName.equals(DTF_TENSORBOARD_JOBNAME)) { + // overwrite task program with TensorBoard executable + myShellEnv.put(DTF_TASK_PROGRAM, TensorBoardProgramFile); + } + + return true; + } + } + } + + return false; + } + + void printHtmlStatusTable(PrintWriter out) { + + if (jobTasks == null) { + out.println("Information not available yet, try again later"); + return; + } + + out.println(""); + out.println(""); + for (Map.Entry entry : jobTasks.entrySet()) { + String jobName = entry.getKey(); + Task[] tasks = entry.getValue(); + + out.println(""); + out.println(""); + for (int i = 0; i < tasks.length; i++) { + Task task = tasks[i]; + + if (i != 0) + out.println(""); + out.println(""); + if (jobName.equals(DTF_TENSORBOARD_JOBNAME) && task.isReady()) { + String hostPort = task.getHostPort(); + out.println(String.format(""); + } + out.println(""); + out.println(""); + out.println(""); + } + out.println(""); + } + out.println("
Job NameIndexHostPortContainerIDexitStatus
" + jobName + "
" + task.getIndex() + "%s", + hostPort, hostPort)); + } else { + out.println("" + task.getHostPort() + "" + task.getContainerId() + "" + task.getExitStatus() + "
"); + + } + + public static class Task { + static private final String FORMAT_HOST_PORT = "%s:%d"; + + private String jobName = ""; + private int taskIndex = -1; + private String host = ""; + private int port = -1; + ContainerId containerId = null; + boolean completed = false; + int exitStatus = -1; + + Task(String name, String host, int index, ContainerId id) { + this.jobName = name; + this.host = host; + this.taskIndex = index; + this.containerId = id; + } + + public String toString() { + return String.format("Task:%s[%d]/%s:%d/%s/%d", jobName, taskIndex, host, port, containerId, exitStatus); + } + + String getJobName() { + return this.jobName; + } + + int getIndex() { + return this.taskIndex; + } + + String getHostPort() { + return String.format(FORMAT_HOST_PORT, host, port < 0 ? 0 : port); + } + + void setPort(int port) { + this.port = port; + } + + ContainerId getContainerId() { + return this.containerId; + } + + void setExitStatus(int status) { + this.completed = true; + this.exitStatus = status; + } + + boolean isCompleted() { + return this.completed; + } + + int getExitStatus() { + return this.exitStatus; + } + + boolean isReady() { + return (!host.isEmpty() && port > 0); + } + } + + static class TFSessionBuilder { + private String clusterReqString; + private String appIDString; + private String taskCmd; + private String taskArgs; + private boolean enableTensorBoard; + private String scriptPath; + private String inputPath; + private String outputPath; + private String dockerImage; + private String registryQuorum; + private String appName; + private TFContainerRequest defaultContainerRequest; + + TFSession build() { + return new TFSession(this); + } + + TFSessionBuilder setClusterReqString(String clusterReqString) { + this.clusterReqString = clusterReqString; + return this; + } + + TFSessionBuilder setAppIDString(String appIDString) { + this.appIDString = appIDString; + return this; + } + + TFSessionBuilder setTaskCmd(String taskCmd) { + this.taskCmd = taskCmd; + return this; + } + + TFSessionBuilder setTaskArgs(String taskArgs) { + this.taskArgs = taskArgs; + return this; + } + + TFSessionBuilder setEnableTensorBoard(boolean enableTensorBoard) { + this.enableTensorBoard = enableTensorBoard; + return this; + } + + TFSessionBuilder setScriptPath(String scriptPath) { + this.scriptPath = scriptPath; + return this; + } + + TFSessionBuilder setInputPath(String inputPath) { + this.inputPath = inputPath; + return this; + } + + TFSessionBuilder setOutputPath(String outputPath) { + this.outputPath = outputPath; + return this; + } + + TFSessionBuilder setDockerImage(String dockerImage) { + this.dockerImage = dockerImage; + return this; + } + + TFSessionBuilder setRegistryQuorum(String registryQuorum) { + this.registryQuorum = registryQuorum; + return this; + } + + TFSessionBuilder setAppName(String appName) { + this.appName = appName; + return this; + } + + TFSessionBuilder setDefaultContainerRequest(TFContainerRequest defaultRequest) { + this.defaultContainerRequest = defaultRequest; + return this; + } + + } +} diff --git a/yarn/src/main/python/task-starter/kazoo/__init__.py b/yarn/src/main/python/task-starter/kazoo/__init__.py new file mode 100644 index 00000000..792d6005 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/__init__.py @@ -0,0 +1 @@ +# diff --git a/yarn/src/main/python/task-starter/kazoo/client.py b/yarn/src/main/python/task-starter/kazoo/client.py new file mode 100644 index 00000000..27fbe003 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/client.py @@ 
-0,0 +1,1561 @@ +"""Kazoo Zookeeper Client""" +import inspect +import logging +import re +import warnings +from collections import defaultdict, deque +from functools import partial +from os.path import split + +import six + +from kazoo.exceptions import ( + AuthFailedError, + ConfigurationError, + ConnectionClosedError, + ConnectionLoss, + KazooException, + NoNodeError, + NodeExistsError, + SessionExpiredError, + WriterNotClosedException, +) +from kazoo.handlers.threading import SequentialThreadingHandler +from kazoo.handlers.utils import capture_exceptions, wrap +from kazoo.hosts import collect_hosts +from kazoo.loggingsupport import BLATHER +from kazoo.protocol.connection import ConnectionHandler +from kazoo.protocol.paths import normpath +from kazoo.protocol.paths import _prefix_root +from kazoo.protocol.serialization import ( + Auth, + CheckVersion, + CloseInstance, + Create, + Delete, + Exists, + GetChildren, + GetChildren2, + GetACL, + SetACL, + GetData, + Reconfig, + SetData, + Sync, + Transaction +) +from kazoo.protocol.states import KazooState +from kazoo.protocol.states import KeeperState +from kazoo.retry import KazooRetry +from kazoo.security import ACL +from kazoo.security import OPEN_ACL_UNSAFE + +# convenience API +from kazoo.recipe.barrier import Barrier +from kazoo.recipe.barrier import DoubleBarrier +from kazoo.recipe.counter import Counter +from kazoo.recipe.election import Election +from kazoo.recipe.lease import NonBlockingLease +from kazoo.recipe.lease import MultiNonBlockingLease +from kazoo.recipe.lock import Lock +from kazoo.recipe.lock import Semaphore +from kazoo.recipe.partitioner import SetPartitioner +from kazoo.recipe.party import Party +from kazoo.recipe.party import ShallowParty +from kazoo.recipe.queue import Queue +from kazoo.recipe.queue import LockingQueue +from kazoo.recipe.watchers import ChildrenWatch +from kazoo.recipe.watchers import DataWatch + +string_types = six.string_types +bytes_types = (six.binary_type,) + +LOST_STATES = (KeeperState.EXPIRED_SESSION, KeeperState.AUTH_FAILED, + KeeperState.CLOSED) +ENVI_VERSION = re.compile('([\d\.]*).*', re.DOTALL) +ENVI_VERSION_KEY = 'zookeeper.version' +log = logging.getLogger(__name__) + + +_RETRY_COMPAT_DEFAULTS = dict( + max_retries=None, + retry_delay=0.1, + retry_backoff=2, + retry_jitter=0.8, + retry_max_delay=3600, +) + +_RETRY_COMPAT_MAPPING = dict( + max_retries='max_tries', + retry_delay='delay', + retry_backoff='backoff', + retry_jitter='max_jitter', + retry_max_delay='max_delay', +) + + +class KazooClient(object): + """An Apache Zookeeper Python client supporting alternate callback + handlers and high-level functionality. + + Watch functions registered with this class will not get session + events, unlike the default Zookeeper watches. They will also be + called with a single argument, a + :class:`~kazoo.protocol.states.WatchedEvent` instance. + + """ + def __init__(self, hosts='127.0.0.1:2181', + timeout=10.0, client_id=None, handler=None, + default_acl=None, auth_data=None, read_only=None, + randomize_hosts=True, connection_retry=None, + command_retry=None, logger=None, **kwargs): + """Create a :class:`KazooClient` instance. All time arguments + are in seconds. + + :param hosts: Comma-separated list of hosts to connect to + (e.g. 127.0.0.1:2181,127.0.0.1:2182,[::1]:2183). + :param timeout: The longest to wait for a Zookeeper connection. + :param client_id: A Zookeeper client id, used when + re-establishing a prior session connection. 
+ :param handler: An instance of a class implementing the + :class:`~kazoo.interfaces.IHandler` interface + for callback handling. + :param default_acl: A default ACL used on node creation. + :param auth_data: + A list of authentication credentials to use for the + connection. Should be a list of (scheme, credential) + tuples as :meth:`add_auth` takes. + :param read_only: Allow connections to read only servers. + :param randomize_hosts: By default randomize host selection. + :param connection_retry: + A :class:`kazoo.retry.KazooRetry` object to use for + retrying the connection to Zookeeper. Also can be a dict of + options which will be used for creating one. + :param command_retry: + A :class:`kazoo.retry.KazooRetry` object to use for + the :meth:`KazooClient.retry` method. Also can be a dict of + options which will be used for creating one. + :param logger: A custom logger to use instead of the module + global `log` instance. + + Basic Example: + + .. code-block:: python + + zk = KazooClient() + zk.start() + children = zk.get_children('/') + zk.stop() + + As a convenience all recipe classes are available as attributes + and get automatically bound to the client. For example:: + + zk = KazooClient() + zk.start() + lock = zk.Lock('/lock_path') + + .. versionadded:: 0.6 + The read_only option. Requires Zookeeper 3.4+ + + .. versionadded:: 0.6 + The retry_max_delay option. + + .. versionadded:: 0.6 + The randomize_hosts option. + + .. versionchanged:: 0.8 + Removed the unused watcher argument (was second argument). + + .. versionadded:: 1.2 + The connection_retry, command_retry and logger options. + + """ + self.logger = logger or log + + # Record the handler strategy used + self.handler = handler if handler else SequentialThreadingHandler() + if inspect.isclass(self.handler): + raise ConfigurationError("Handler must be an instance of a class, " + "not the class: %s" % self.handler) + + self.auth_data = auth_data if auth_data else set([]) + self.default_acl = default_acl + self.randomize_hosts = randomize_hosts + self.hosts = None + self.chroot = None + self.set_hosts(hosts) + + # Curator like simplified state tracking, and listeners for + # state transitions + self._state = KeeperState.CLOSED + self.state = KazooState.LOST + self.state_listeners = set() + + self._reset() + self.read_only = read_only + + if client_id: + self._session_id = client_id[0] + self._session_passwd = client_id[1] + else: + self._reset_session() + + # ZK uses milliseconds + self._session_timeout = int(timeout * 1000) + + # We use events like twitter's client to track current and + # desired state (connected, and whether to shutdown) + self._live = self.handler.event_object() + self._writer_stopped = self.handler.event_object() + self._stopped = self.handler.event_object() + self._stopped.set() + self._writer_stopped.set() + + self.retry = self._conn_retry = None + + if type(connection_retry) is dict: + self._conn_retry = KazooRetry(**connection_retry) + elif type(connection_retry) is KazooRetry: + self._conn_retry = connection_retry + + if type(command_retry) is dict: + self.retry = KazooRetry(**command_retry) + elif type(command_retry) is KazooRetry: + self.retry = command_retry + + if type(self._conn_retry) is KazooRetry: + if self.handler.sleep_func != self._conn_retry.sleep_func: + raise ConfigurationError("Retry handler and event handler " + " must use the same sleep func") + + if type(self.retry) is KazooRetry: + if self.handler.sleep_func != self.retry.sleep_func: + raise ConfigurationError( + "Command retry 
handler and event handler " + "must use the same sleep func") + + if self.retry is None or self._conn_retry is None: + old_retry_keys = dict(_RETRY_COMPAT_DEFAULTS) + for key in old_retry_keys: + try: + old_retry_keys[key] = kwargs.pop(key) + warnings.warn( + 'Passing retry configuration param %s to the ' + 'client directly is deprecated, please pass a ' + 'configured retry object (using param %s)' % ( + key, _RETRY_COMPAT_MAPPING[key]), + DeprecationWarning, stacklevel=2) + except KeyError: + pass + + retry_keys = {} + for oldname, value in old_retry_keys.items(): + retry_keys[_RETRY_COMPAT_MAPPING[oldname]] = value + + if self._conn_retry is None: + self._conn_retry = KazooRetry( + sleep_func=self.handler.sleep_func, + **retry_keys) + if self.retry is None: + self.retry = KazooRetry( + sleep_func=self.handler.sleep_func, + **retry_keys) + + self._conn_retry.interrupt = lambda: self._stopped.is_set() + self._connection = ConnectionHandler( + self, self._conn_retry.copy(), logger=self.logger) + + # Every retry call should have its own copy of the retry helper + # to avoid shared retry counts + self._retry = self.retry + + def _retry(*args, **kwargs): + return self._retry.copy()(*args, **kwargs) + self.retry = _retry + + self.Barrier = partial(Barrier, self) + self.Counter = partial(Counter, self) + self.DoubleBarrier = partial(DoubleBarrier, self) + self.ChildrenWatch = partial(ChildrenWatch, self) + self.DataWatch = partial(DataWatch, self) + self.Election = partial(Election, self) + self.NonBlockingLease = partial(NonBlockingLease, self) + self.MultiNonBlockingLease = partial(MultiNonBlockingLease, self) + self.Lock = partial(Lock, self) + self.Party = partial(Party, self) + self.Queue = partial(Queue, self) + self.LockingQueue = partial(LockingQueue, self) + self.SetPartitioner = partial(SetPartitioner, self) + self.Semaphore = partial(Semaphore, self) + self.ShallowParty = partial(ShallowParty, self) + + # If we got any unhandled keywords, complain like Python would + if kwargs: + raise TypeError('__init__() got unexpected keyword arguments: %s' + % (kwargs.keys(),)) + + def _reset(self): + """Resets a variety of client states for a new connection.""" + self._queue = deque() + self._pending = deque() + + self._reset_watchers() + self._reset_session() + self.last_zxid = 0 + self._protocol_version = None + + def _reset_watchers(self): + self._child_watchers = defaultdict(set) + self._data_watchers = defaultdict(set) + + def _reset_session(self): + self._session_id = None + self._session_passwd = b'\x00' * 16 + + @property + def client_state(self): + """Returns the last Zookeeper client state + + This is the non-simplified state information and is generally + not as useful as the simplified KazooState information. + + """ + return self._state + + @property + def client_id(self): + """Returns the client id for this Zookeeper session if + connected. + + :returns: client id which consists of the session id and + password. + :rtype: tuple + """ + if self._live.is_set(): + return (self._session_id, self._session_passwd) + return None + + @property + def connected(self): + """Returns whether the Zookeeper connection has been + established.""" + return self._live.is_set() + + def set_hosts(self, hosts, randomize_hosts=None): + """ sets the list of hosts used by this client. + + This function accepts the same format hosts parameter as the init + function and sets the client to use the new hosts the next time it + needs to look up a set of hosts. 
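+
+        For example (the host list below is illustrative)::
+
+            zk.set_hosts('10.0.0.1:2181,10.0.0.2:2181,10.0.0.3:2181')
+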
This function does not affect the + current connected status. + + It is not currently possible to change the chroot with this function, + setting a host list with a new chroot will raise a ConfigurationError. + + :param hosts: see description in :meth:`KazooClient.__init__` + :param randomize_hosts: override client default for host randomization + :raises: + :exc:`ConfigurationError` if the hosts argument changes the chroot + + .. versionadded:: 1.4 + + .. warning:: + + Using this function to point a client to a completely disparate + zookeeper server cluster has undefined behavior. + + """ + + if randomize_hosts is None: + randomize_hosts = self.randomize_hosts + + self.hosts, chroot = collect_hosts(hosts, randomize_hosts) + + if chroot: + new_chroot = normpath(chroot) + else: + new_chroot = '' + + if self.chroot is not None and new_chroot != self.chroot: + raise ConfigurationError("Changing chroot at runtime is not " + "currently supported") + + self.chroot = new_chroot + + def add_listener(self, listener): + """Add a function to be called for connection state changes. + + This function will be called with a + :class:`~kazoo.protocol.states.KazooState` instance indicating + the new connection state on state transitions. + + .. warning:: + + This function must not block. If its at all likely that it + might need data or a value that could result in blocking + than the :meth:`~kazoo.interfaces.IHandler.spawn` method + should be used so that the listener can return immediately. + + """ + if not (listener and callable(listener)): + raise ConfigurationError("listener must be callable") + self.state_listeners.add(listener) + + def remove_listener(self, listener): + """Remove a listener function""" + self.state_listeners.discard(listener) + + def _make_state_change(self, state): + # skip if state is current + if self.state == state: + return + + self.state = state + + # Create copy of listeners for iteration in case one needs to + # remove itself + for listener in list(self.state_listeners): + try: + remove = listener(state) + if remove is True: + self.remove_listener(listener) + except Exception: + self.logger.exception("Error in connection state listener") + + def _session_callback(self, state): + if state == self._state: + return + + # Note that we don't check self.state == LOST since that's also + # the client's initial state + dead_state = self._state in LOST_STATES + self._state = state + + # If we were previously closed or had an expired session, and + # are now connecting, don't bother with the rest of the + # transitions since they only apply after + # we've established a connection + if dead_state and state == KeeperState.CONNECTING: + self.logger.log(BLATHER, "Skipping state change") + return + + if state in (KeeperState.CONNECTED, KeeperState.CONNECTED_RO): + self.logger.info("Zookeeper connection established, " + "state: %s", state) + self._live.set() + self._make_state_change(KazooState.CONNECTED) + elif state in LOST_STATES: + self.logger.info("Zookeeper session lost, state: %s", state) + self._live.clear() + self._make_state_change(KazooState.LOST) + self._notify_pending(state) + self._reset() + else: + self.logger.info("Zookeeper connection lost") + # Connection lost + self._live.clear() + self._notify_pending(state) + self._make_state_change(KazooState.SUSPENDED) + self._reset_watchers() + + def _notify_pending(self, state): + """Used to clear a pending response queue and request queue + during connection drops.""" + if state == KeeperState.AUTH_FAILED: + exc = 
AuthFailedError() + elif state == KeeperState.EXPIRED_SESSION: + exc = SessionExpiredError() + else: + exc = ConnectionLoss() + + while True: + try: + request, async_object, xid = self._pending.popleft() + if async_object: + async_object.set_exception(exc) + except IndexError: + break + + while True: + try: + request, async_object = self._queue.popleft() + if async_object: + async_object.set_exception(exc) + except IndexError: + break + + def _safe_close(self): + self.handler.stop() + timeout = self._session_timeout // 1000 + if timeout < 10: + timeout = 10 + if not self._connection.stop(timeout): + raise WriterNotClosedException( + "Writer still open from prior connection " + "and wouldn't close after %s seconds" % timeout) + + def _call(self, request, async_object): + """Ensure there's an active connection and put the request in + the queue if there is. + + Returns False if the call short circuits due to AUTH_FAILED, + CLOSED, EXPIRED_SESSION or CONNECTING state. + + """ + + if self._state == KeeperState.AUTH_FAILED: + async_object.set_exception(AuthFailedError()) + return False + elif self._state == KeeperState.CLOSED: + async_object.set_exception(ConnectionClosedError( + "Connection has been closed")) + return False + elif self._state in (KeeperState.EXPIRED_SESSION, + KeeperState.CONNECTING): + async_object.set_exception(SessionExpiredError()) + return False + + self._queue.append((request, async_object)) + + # wake the connection, guarding against a race with close() + write_sock = self._connection._write_sock + if write_sock is None: + async_object.set_exception(ConnectionClosedError( + "Connection has been closed")) + try: + write_sock.send(b'\0') + except: + async_object.set_exception(ConnectionClosedError( + "Connection has been closed")) + + def start(self, timeout=15): + """Initiate connection to ZK. + + :param timeout: Time in seconds to wait for connection to + succeed. + :raises: :attr:`~kazoo.interfaces.IHandler.timeout_exception` + if the connection wasn't established within `timeout` + seconds. + + """ + event = self.start_async() + event.wait(timeout=timeout) + if not self.connected: + # We time-out, ensure we are disconnected + self.stop() + raise self.handler.timeout_exception("Connection time-out") + + if self.chroot and not self.exists("/"): + warnings.warn("No chroot path exists, the chroot path " + "should be created before normal use.") + + def start_async(self): + """Asynchronously initiate connection to ZK. + + :returns: An event object that can be checked to see if the + connection is alive. + :rtype: :class:`~threading.Event` compatible object. + + """ + # If we're already connected, ignore + if self._live.is_set(): + return self._live + + # Make sure we're safely closed + self._safe_close() + + # We've been asked to connect, clear the stop and our writer + # thread indicator + self._stopped.clear() + self._writer_stopped.clear() + + # Start the handler + self.handler.start() + + # Start the connection + self._connection.start() + return self._live + + def stop(self): + """Gracefully stop this Zookeeper session. + + This method can be called while a reconnection attempt is in + progress, which will then be halted. + + Once the connection is closed, its session becomes invalid. All + the ephemeral nodes in the ZooKeeper server associated with the + session will be removed. The watches left on those nodes (and + on their parents) will be triggered. 
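+
+        A minimal usage sketch (the znode path below is illustrative)::
+
+            zk = KazooClient(hosts='127.0.0.1:2181')
+            zk.start()
+            try:
+                zk.ensure_path('/myapp/state')
+            finally:
+                zk.stop()
+                zk.close()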
+ + """ + if self._stopped.is_set(): + return + + self._stopped.set() + self._queue.append((CloseInstance, None)) + self._connection._write_sock.send(b'\0') + self._safe_close() + + def restart(self): + """Stop and restart the Zookeeper session.""" + self.stop() + self.start() + + def close(self): + """Free any resources held by the client. + + This method should be called on a stopped client before it is + discarded. Not doing so may result in filehandles being leaked. + + .. versionadded:: 1.0 + """ + self._connection.close() + + def command(self, cmd=b'ruok'): + """Sent a management command to the current ZK server. + + Examples are `ruok`, `envi` or `stat`. + + :returns: An unstructured textual response. + :rtype: str + + :raises: + :exc:`ConnectionLoss` if there is no connection open, or + possibly a :exc:`socket.error` if there's a problem with + the connection used just for this command. + + .. versionadded:: 0.5 + + """ + if not self._live.is_set(): + raise ConnectionLoss("No connection to server") + + peer = self._connection._socket.getpeername()[:2] + sock = self.handler.create_connection( + peer, timeout=self._session_timeout / 1000.0) + sock.sendall(cmd) + result = sock.recv(8192) + sock.close() + return result.decode('utf-8', 'replace') + + def server_version(self, retries=3): + """Get the version of the currently connected ZK server. + + :returns: The server version, for example (3, 4, 3). + :rtype: tuple + + .. versionadded:: 0.5 + + """ + def _try_fetch(): + data = self.command(b'envi') + data_parsed = {} + for line in data.splitlines(): + try: + k, v = line.split("=", 1) + k = k.strip() + v = v.strip() + except ValueError: + pass + else: + if k: + data_parsed[k] = v + version = data_parsed.get(ENVI_VERSION_KEY, '') + version_digits = ENVI_VERSION.match(version).group(1) + try: + return tuple([int(d) for d in version_digits.split('.')]) + except ValueError: + return None + + def _is_valid(version): + # All zookeeper versions should have at least major.minor + # version numbers; if we get one that doesn't it is likely not + # correct and was truncated... + if version and len(version) > 1: + return True + return False + + # Try 1 + retries amount of times to get a version that we know + # will likely be acceptable... + version = _try_fetch() + if _is_valid(version): + return version + for _i in six.moves.range(0, retries): + version = _try_fetch() + if _is_valid(version): + return version + raise KazooException("Unable to fetch useable server" + " version after trying %s times" + % (1 + max(0, retries))) + + def add_auth(self, scheme, credential): + """Send credentials to server. + + :param scheme: authentication scheme (default supported: + "digest"). + :param credential: the credential -- value depends on scheme. + + :returns: True if it was successful. + :rtype: bool + + :raises: + :exc:`~kazoo.exceptions.AuthFailedError` if it failed though + the session state will be set to AUTH_FAILED as well. + + """ + return self.add_auth_async(scheme, credential).get() + + def add_auth_async(self, scheme, credential): + """Asynchronously send credentials to server. Takes the same + arguments as :meth:`add_auth`. 
+ + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + if not isinstance(scheme, string_types): + raise TypeError("Invalid type for 'scheme' (string expected)") + if not isinstance(credential, string_types): + raise TypeError("Invalid type for 'credential' (string expected)") + + # we need this auth data to re-authenticate on reconnect + self.auth_data.add((scheme, credential)) + + async_result = self.handler.async_result() + self._call(Auth(0, scheme, credential), async_result) + return async_result + + def unchroot(self, path): + """Strip the chroot if applicable from the path.""" + if not self.chroot: + return path + + if path.startswith(self.chroot): + return path[len(self.chroot):] + else: + return path + + def sync_async(self, path): + """Asynchronous sync. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + async_result = self.handler.async_result() + self._call(Sync(_prefix_root(self.chroot, path)), async_result) + return async_result + + def sync(self, path): + """Sync, blocks until response is acknowledged. + + Flushes channel between process and leader. + + :param path: path of node. + :returns: The node path that was synced. + :raises: + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + + .. versionadded:: 0.5 + + """ + return self.sync_async(path).get() + + def create(self, path, value=b"", acl=None, ephemeral=False, + sequence=False, makepath=False): + """Create a node with the given value as its data. Optionally + set an ACL on the node. + + The ephemeral and sequence arguments determine the type of the + node. + + An ephemeral node will be automatically removed by ZooKeeper + when the session associated with the creation of the node + expires. + + A sequential node will be given the specified path plus a + suffix `i` where i is the current sequential number of the + node. The sequence number is always fixed length of 10 digits, + 0 padded. Once such a node is created, the sequential number + will be incremented by one. + + If a node with the same actual path already exists in + ZooKeeper, a NodeExistsError will be raised. Note that since a + different actual path is used for each invocation of creating + sequential nodes with the same path argument, the call will + never raise NodeExistsError. + + If the parent node does not exist in ZooKeeper, a NoNodeError + will be raised. Setting the optional `makepath` argument to + `True` will create all missing parent nodes instead. + + An ephemeral node cannot have children. If the parent node of + the given path is ephemeral, a NoChildrenForEphemeralsError + will be raised. + + This operation, if successful, will trigger all the watches + left on the node of the given path by :meth:`exists` and + :meth:`get` API calls, and the watches left on the parent node + by :meth:`get_children` API calls. + + The maximum allowable size of the node value is 1 MB. Values + larger than this will cause a ZookeeperError to be raised. + + :param path: Path of node. + :param value: Initial bytes value of node. + :param acl: :class:`~kazoo.security.ACL` list. + :param ephemeral: Boolean indicating whether node is ephemeral + (tied to this session). + :param sequence: Boolean indicating whether path is suffixed + with a unique index. + :param makepath: Whether the path should be created if it + doesn't exist. + :returns: Real path of the new node. + :rtype: str + + :raises: + :exc:`~kazoo.exceptions.NodeExistsError` if the node + already exists. 
+ + :exc:`~kazoo.exceptions.NoNodeError` if parent nodes are + missing. + + :exc:`~kazoo.exceptions.NoChildrenForEphemeralsError` if + the parent node is an ephemeral node. + + :exc:`~kazoo.exceptions.ZookeeperError` if the provided + value is too large. + + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + + """ + acl = acl or self.default_acl + return self.create_async(path, value, acl=acl, ephemeral=ephemeral, + sequence=sequence, makepath=makepath).get() + + def create_async(self, path, value=b"", acl=None, ephemeral=False, + sequence=False, makepath=False): + """Asynchronously create a ZNode. Takes the same arguments as + :meth:`create`. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + .. versionadded:: 1.1 + The makepath option. + + """ + if acl is None and self.default_acl: + acl = self.default_acl + + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + if acl and (isinstance(acl, ACL) or + not isinstance(acl, (tuple, list))): + raise TypeError("Invalid type for 'acl' (acl must be a tuple/list" + " of ACL's") + if value is not None and not isinstance(value, bytes_types): + raise TypeError("Invalid type for 'value' (must be a byte string)") + if not isinstance(ephemeral, bool): + raise TypeError("Invalid type for 'ephemeral' (bool expected)") + if not isinstance(sequence, bool): + raise TypeError("Invalid type for 'sequence' (bool expected)") + if not isinstance(makepath, bool): + raise TypeError("Invalid type for 'makepath' (bool expected)") + + flags = 0 + if ephemeral: + flags |= 1 + if sequence: + flags |= 2 + if acl is None: + acl = OPEN_ACL_UNSAFE + + async_result = self.handler.async_result() + + @capture_exceptions(async_result) + def do_create(): + result = self._create_async_inner( + path, value, acl, flags, trailing=sequence) + result.rawlink(create_completion) + + @capture_exceptions(async_result) + def retry_completion(result): + result.get() + do_create() + + @wrap(async_result) + def create_completion(result): + try: + return self.unchroot(result.get()) + except NoNodeError: + if not makepath: + raise + if sequence and path.endswith('/'): + parent = path.rstrip('/') + else: + parent, _ = split(path) + self.ensure_path_async(parent, acl).rawlink(retry_completion) + + do_create() + return async_result + + def _create_async_inner(self, path, value, acl, flags, trailing=False): + async_result = self.handler.async_result() + call_result = self._call( + Create(_prefix_root(self.chroot, path, trailing=trailing), + value, acl, flags), async_result) + if call_result is False: + # We hit a short-circuit exit on the _call. Because we are + # not using the original async_result here, we bubble the + # exception upwards to the do_create function in + # KazooClient.create so that it gets set on the correct + # async_result object + raise async_result.exception + return async_result + + def ensure_path(self, path, acl=None): + """Recursively create a path if it doesn't exist. + + :param path: Path of node. + :param acl: Permissions for node. + + """ + return self.ensure_path_async(path, acl).get() + + def ensure_path_async(self, path, acl=None): + """Recursively create a path asynchronously if it doesn't + exist. Takes the same arguments as :meth:`ensure_path`. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + .. 
versionadded:: 1.1 + + """ + acl = acl or self.default_acl + async_result = self.handler.async_result() + + @wrap(async_result) + def create_completion(result): + try: + return result.get() + except NodeExistsError: + return True + + @capture_exceptions(async_result) + def prepare_completion(next_path, result): + result.get() + self.create_async(next_path, acl=acl).rawlink(create_completion) + + @wrap(async_result) + def exists_completion(path, result): + if result.get(): + return True + parent, node = split(path) + if node: + self.ensure_path_async(parent, acl=acl).rawlink( + partial(prepare_completion, path)) + else: + self.create_async(path, acl=acl).rawlink(create_completion) + + self.exists_async(path).rawlink(partial(exists_completion, path)) + + return async_result + + def exists(self, path, watch=None): + """Check if a node exists. + + If a watch is provided, it will be left on the node with the + given path. The watch will be triggered by a successful + operation that creates/deletes the node or sets the data on the + node. + + :param path: Path of node. + :param watch: Optional watch callback to set for future changes + to this path. + :returns: ZnodeStat of the node if it exists, else None if the + node does not exist. + :rtype: :class:`~kazoo.protocol.states.ZnodeStat` or `None`. + + :raises: + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + + """ + return self.exists_async(path, watch).get() + + def exists_async(self, path, watch=None): + """Asynchronously check if a node exists. Takes the same + arguments as :meth:`exists`. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + if watch and not callable(watch): + raise TypeError("Invalid type for 'watch' (must be a callable)") + + async_result = self.handler.async_result() + self._call(Exists(_prefix_root(self.chroot, path), watch), + async_result) + return async_result + + def get(self, path, watch=None): + """Get the value of a node. + + If a watch is provided, it will be left on the node with the + given path. The watch will be triggered by a successful + operation that sets data on the node, or deletes the node. + + :param path: Path of node. + :param watch: Optional watch callback to set for future changes + to this path. + :returns: + Tuple (value, :class:`~kazoo.protocol.states.ZnodeStat`) of + node. + :rtype: tuple + + :raises: + :exc:`~kazoo.exceptions.NoNodeError` if the node doesn't + exist + + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code + + """ + return self.get_async(path, watch).get() + + def get_async(self, path, watch=None): + """Asynchronously get the value of a node. Takes the same + arguments as :meth:`get`. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + if watch and not callable(watch): + raise TypeError("Invalid type for 'watch' (must be a callable)") + + async_result = self.handler.async_result() + self._call(GetData(_prefix_root(self.chroot, path), watch), + async_result) + return async_result + + def get_children(self, path, watch=None, include_data=False): + """Get a list of child nodes of a path. + + If a watch is provided it will be left on the node with the + given path. 
The watch will be triggered by a successful + operation that deletes the node of the given path or + creates/deletes a child under the node. + + The list of children returned is not sorted and no guarantee is + provided as to its natural or lexical order. + + :param path: Path of node to list. + :param watch: Optional watch callback to set for future changes + to this path. + :param include_data: + Include the :class:`~kazoo.protocol.states.ZnodeStat` of + the node in addition to the children. This option changes + the return value to be a tuple of (children, stat). + + :returns: List of child node names, or tuple if `include_data` + is `True`. + :rtype: list + + :raises: + :exc:`~kazoo.exceptions.NoNodeError` if the node doesn't + exist. + + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + + .. versionadded:: 0.5 + The `include_data` option. + + """ + return self.get_children_async(path, watch, include_data).get() + + def get_children_async(self, path, watch=None, include_data=False): + """Asynchronously get a list of child nodes of a path. Takes + the same arguments as :meth:`get_children`. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + if watch and not callable(watch): + raise TypeError("Invalid type for 'watch' (must be a callable)") + if not isinstance(include_data, bool): + raise TypeError("Invalid type for 'include_data' (bool expected)") + + async_result = self.handler.async_result() + if include_data: + req = GetChildren2(_prefix_root(self.chroot, path), watch) + else: + req = GetChildren(_prefix_root(self.chroot, path), watch) + self._call(req, async_result) + return async_result + + def get_acls(self, path): + """Return the ACL and stat of the node of the given path. + + :param path: Path of the node. + :returns: The ACL array of the given node and its + :class:`~kazoo.protocol.states.ZnodeStat`. + :rtype: tuple of (:class:`~kazoo.security.ACL` list, + :class:`~kazoo.protocol.states.ZnodeStat`) + :raises: + :exc:`~kazoo.exceptions.NoNodeError` if the node doesn't + exist. + + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code + + .. versionadded:: 0.5 + + """ + return self.get_acls_async(path).get() + + def get_acls_async(self, path): + """Return the ACL and stat of the node of the given path. Takes + the same arguments as :meth:`get_acls`. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + + async_result = self.handler.async_result() + self._call(GetACL(_prefix_root(self.chroot, path)), async_result) + return async_result + + def set_acls(self, path, acls, version=-1): + """Set the ACL for the node of the given path. + + Set the ACL for the node of the given path if such a node + exists and the given version matches the version of the node. + + :param path: Path for the node. + :param acls: List of :class:`~kazoo.security.ACL` objects to + set. + :param version: The expected node version that must match. + :returns: The stat of the node. + :raises: + :exc:`~kazoo.exceptions.BadVersionError` if version doesn't + match. + + :exc:`~kazoo.exceptions.NoNodeError` if the node doesn't + exist. + + :exc:`~kazoo.exceptions.InvalidACLError` if the ACL is + invalid. + + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + + .. 
versionadded:: 0.5 + + """ + return self.set_acls_async(path, acls, version).get() + + def set_acls_async(self, path, acls, version=-1): + """Set the ACL for the node of the given path. Takes the same + arguments as :meth:`set_acls`. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + if isinstance(acls, ACL) or not isinstance(acls, (tuple, list)): + raise TypeError("Invalid type for 'acl' (acl must be a tuple/list" + " of ACL's") + if not isinstance(version, int): + raise TypeError("Invalid type for 'version' (int expected)") + + async_result = self.handler.async_result() + self._call(SetACL(_prefix_root(self.chroot, path), acls, version), + async_result) + return async_result + + def set(self, path, value, version=-1): + """Set the value of a node. + + If the version of the node being updated is newer than the + supplied version (and the supplied version is not -1), a + BadVersionError will be raised. + + This operation, if successful, will trigger all the watches on + the node of the given path left by :meth:`get` API calls. + + The maximum allowable size of the value is 1 MB. Values larger + than this will cause a ZookeeperError to be raised. + + :param path: Path of node. + :param value: New data value. + :param version: Version of node being updated, or -1. + :returns: Updated :class:`~kazoo.protocol.states.ZnodeStat` of + the node. + + :raises: + :exc:`~kazoo.exceptions.BadVersionError` if version doesn't + match. + + :exc:`~kazoo.exceptions.NoNodeError` if the node doesn't + exist. + + :exc:`~kazoo.exceptions.ZookeeperError` if the provided + value is too large. + + :exc:`~kazoo.exceptions.ZookeeperError` if the server + returns a non-zero error code. + + """ + return self.set_async(path, value, version).get() + + def set_async(self, path, value, version=-1): + """Set the value of a node. Takes the same arguments as + :meth:`set`. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + if value is not None and not isinstance(value, bytes_types): + raise TypeError("Invalid type for 'value' (must be a byte string)") + if not isinstance(version, int): + raise TypeError("Invalid type for 'version' (int expected)") + + async_result = self.handler.async_result() + self._call(SetData(_prefix_root(self.chroot, path), value, version), + async_result) + return async_result + + def transaction(self): + """Create and return a :class:`TransactionRequest` object + + Creates a :class:`TransactionRequest` object. A Transaction can + consist of multiple operations which can be committed as a + single atomic unit. Either all of the operations will succeed + or none of them. + + :returns: A TransactionRequest. + :rtype: :class:`TransactionRequest` + + .. versionadded:: 0.6 + Requires Zookeeper 3.4+ + + """ + return TransactionRequest(self) + + def delete(self, path, version=-1, recursive=False): + """Delete a node. + + The call will succeed if such a node exists, and the given + version matches the node's version (if the given version is -1, + the default, it matches any node's versions). + + This operation, if successful, will trigger all the watches on + the node of the given path left by `exists` API calls, and the + watches on the parent node left by `get_children` API calls. + + :param path: Path of node to delete. + :param version: Version of node to delete, or -1 for any. 
+        :param recursive: Recursively delete node and all its children,
+            defaults to False.
+        :type recursive: bool
+
+        :raises:
+            :exc:`~kazoo.exceptions.BadVersionError` if version doesn't
+            match.
+
+            :exc:`~kazoo.exceptions.NoNodeError` if the node doesn't
+            exist.
+
+            :exc:`~kazoo.exceptions.NotEmptyError` if the node has
+            children.
+
+            :exc:`~kazoo.exceptions.ZookeeperError` if the server
+            returns a non-zero error code.
+
+        """
+        if not isinstance(recursive, bool):
+            raise TypeError("Invalid type for 'recursive' (bool expected)")
+        if recursive:
+            return self._delete_recursive(path)
+        else:
+            return self.delete_async(path, version).get()
+
+    def delete_async(self, path, version=-1):
+        """Asynchronously delete a node. Takes the same arguments as
+        :meth:`delete`, with the exception of `recursive`.
+
+        :rtype: :class:`~kazoo.interfaces.IAsyncResult`
+
+        """
+        if not isinstance(path, string_types):
+            raise TypeError("Invalid type for 'path' (string expected)")
+        if not isinstance(version, int):
+            raise TypeError("Invalid type for 'version' (int expected)")
+        async_result = self.handler.async_result()
+        self._call(Delete(_prefix_root(self.chroot, path), version),
+                   async_result)
+        return async_result
+
+    def _delete_recursive(self, path):
+        try:
+            children = self.get_children(path)
+        except NoNodeError:
+            return True
+
+        if children:
+            for child in children:
+                if path == "/":
+                    child_path = path + child
+                else:
+                    child_path = path + "/" + child
+
+                self._delete_recursive(child_path)
+        try:
+            self.delete(path)
+        except NoNodeError:  # pragma: nocover
+            pass
+
+    def reconfig(self, joining, leaving, new_members, from_config=-1):
+        """Reconfig a cluster.
+
+        This call will succeed if the cluster was reconfigured accordingly.
+
+        :param joining: a comma separated list of servers being added
+                        (see example for format) (incremental reconfiguration)
+        :param leaving: a comma separated list of servers being removed
+                        (see example for format) (incremental reconfiguration)
+        :param new_members: a comma separated list of new membership
+                            (non-incremental reconfiguration)
+        :param from_config: version of the current configuration (optional -
+                            causes reconfiguration to throw an exception if
+                            configuration is no longer current)
+        :type from_config: int
+        :returns:
+            Tuple (value, :class:`~kazoo.protocol.states.ZnodeStat`) of
+            node.
+        :rtype: tuple
+
+        Basic Example:
+
+        .. code-block:: python
+
+            zk = KazooClient()
+            zk.start()
+
+            # first add an observer (incremental reconfiguration)
+            joining = 'server.100=10.0.0.10:2889:3888:observer;0.0.0.0:2181'
+            data, _ = zk.reconfig(
+                joining=joining, leaving=None, new_members=None)
+
+            # wait and then remove it (just by using its id) (incremental)
+            data, _ = zk.reconfig(joining=None, leaving='100', new_members=None)
+
+            # now do a full change of the cluster (non-incremental)
+            new = [
+                'server.100=10.0.0.10:2889:3888:observer;0.0.0.0:2181',
+                'server.101=10.0.0.11:2889:3888:observer;0.0.0.0:2181',
+                'server.102=10.0.0.12:2889:3888:observer;0.0.0.0:2181',
+            ]
+            data, _ = zk.reconfig(
+                joining=None, leaving=None, new_members=','.join(new))
+
+            zk.stop()
+
+        :raises:
+            :exc:`~kazoo.exceptions.UnimplementedError` if not supported.
+
+            :exc:`~kazoo.exceptions.NewConfigNoQuorumError` if no quorum of new
+            config is connected and up-to-date with the leader of last
+            committed config - try invoking reconfiguration after new servers
+            are connected and synced.
+
+            :exc:`~kazoo.exceptions.ReconfigInProcessError` if another
+            reconfiguration is in progress.
+
+            :exc:`~kazoo.exceptions.BadVersionError` if version doesn't
+            match.
+
+            :exc:`~kazoo.exceptions.BadArgumentsError` if any of the given
+            lists of servers has a bad format.
+
+            :exc:`~kazoo.exceptions.ZookeeperError` if the server
+            returns a non-zero error code.
+
+        """
+        result = self.reconfig_async(joining, leaving, new_members, from_config)
+        return result.get()
+
+    def reconfig_async(self, joining, leaving, new_members, from_config):
+        """Asynchronously reconfig a cluster. Takes the same arguments as
+        :meth:`reconfig`.
+
+        :rtype: :class:`~kazoo.interfaces.IAsyncResult`
+
+        """
+        if joining and not isinstance(joining, string_types):
+            raise TypeError("Invalid type for 'joining' (string expected)")
+        if leaving and not isinstance(leaving, string_types):
+            raise TypeError("Invalid type for 'leaving' (string expected)")
+        if new_members and not isinstance(new_members, string_types):
+            raise TypeError("Invalid type for 'new_members' (string "
+                            "expected)")
+        if not isinstance(from_config, int):
+            raise TypeError("Invalid type for 'from_config' (int expected)")
+
+        async_result = self.handler.async_result()
+        reconfig = Reconfig(joining, leaving, new_members, from_config)
+        self._call(reconfig, async_result)
+
+        return async_result
+
+
+class TransactionRequest(object):
+    """A Zookeeper Transaction Request
+
+    A Transaction provides a builder object that can be used to
+    construct and commit an atomic set of operations. The transaction
+    is not sent to Zookeeper until it is committed.
+
+    Transactions are not thread-safe and should not be accessed from
+    multiple threads at once.
+
+    .. note::
+
+        The ``committed`` attribute only indicates whether this
+        transaction has been sent to Zookeeper and is used to prevent
+        duplicate commits of the same transaction. The result should be
+        checked to determine if the transaction executed as desired.
+
+    .. versionadded:: 0.6
+        Requires Zookeeper 3.4+
+
+    """
+    def __init__(self, client):
+        self.client = client
+        self.operations = []
+        self.committed = False
+
+    def create(self, path, value=b"", acl=None, ephemeral=False,
+               sequence=False):
+        """Add a create ZNode to the transaction. Takes the same
+        arguments as :meth:`KazooClient.create`, with the exception
+        of `makepath`.
+
+        :returns: None
+
+        """
+        if acl is None and self.client.default_acl:
+            acl = self.client.default_acl
+
+        if not isinstance(path, string_types):
+            raise TypeError("Invalid type for 'path' (string expected)")
+        if acl and not isinstance(acl, (tuple, list)):
+            raise TypeError("Invalid type for 'acl' (must be a tuple/list"
+                            " of ACLs)")
+        if not isinstance(value, bytes_types):
+            raise TypeError("Invalid type for 'value' (must be a byte string)")
+        if not isinstance(ephemeral, bool):
+            raise TypeError("Invalid type for 'ephemeral' (bool expected)")
+        if not isinstance(sequence, bool):
+            raise TypeError("Invalid type for 'sequence' (bool expected)")
+
+        flags = 0
+        if ephemeral:
+            flags |= 1
+        if sequence:
+            flags |= 2
+        if acl is None:
+            acl = OPEN_ACL_UNSAFE
+
+        self._add(Create(_prefix_root(self.client.chroot, path), value, acl,
+                         flags), None)
+
+    def delete(self, path, version=-1):
+        """Add a delete ZNode to the transaction. Takes the same
+        arguments as :meth:`KazooClient.delete`, with the exception of
+        `recursive`.
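+
+        A minimal usage sketch (assuming ``client`` is a connected
+        :class:`KazooClient`; the paths are illustrative):
+
+        .. code-block:: python
+
+            tx = client.transaction()
+            tx.delete('/app/stale')
+            tx.create('/app/fresh', b'new')
+            results = tx.commit()  # all operations apply, or none do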
+ + """ + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + if not isinstance(version, int): + raise TypeError("Invalid type for 'version' (int expected)") + self._add(Delete(_prefix_root(self.client.chroot, path), version)) + + def set_data(self, path, value, version=-1): + """Add a set ZNode value to the transaction. Takes the same + arguments as :meth:`KazooClient.set`. + + """ + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + if not isinstance(value, bytes_types): + raise TypeError("Invalid type for 'value' (must be a byte string)") + if not isinstance(version, int): + raise TypeError("Invalid type for 'version' (int expected)") + self._add(SetData(_prefix_root(self.client.chroot, path), value, + version)) + + def check(self, path, version): + """Add a Check Version to the transaction. + + This command will fail and abort a transaction if the path + does not match the specified version. + + """ + if not isinstance(path, string_types): + raise TypeError("Invalid type for 'path' (string expected)") + if not isinstance(version, int): + raise TypeError("Invalid type for 'version' (int expected)") + self._add(CheckVersion(_prefix_root(self.client.chroot, path), + version)) + + def commit_async(self): + """Commit the transaction asynchronously. + + :rtype: :class:`~kazoo.interfaces.IAsyncResult` + + """ + self._check_tx_state() + self.committed = True + async_object = self.client.handler.async_result() + self.client._call(Transaction(self.operations), async_object) + return async_object + + def commit(self): + """Commit the transaction. + + :returns: A list of the results for each operation in the + transaction. + + """ + return self.commit_async().get() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_tb): + """Commit and cleanup accumulated transaction data.""" + if not exc_type: + self.commit() + + def _check_tx_state(self): + if self.committed: + raise ValueError('Transaction already committed') + + def _add(self, request, post_processor=None): + self._check_tx_state() + self.client.logger.log(BLATHER, 'Added %r to %r', request, self) + self.operations.append(request) diff --git a/yarn/src/main/python/task-starter/kazoo/exceptions.py b/yarn/src/main/python/task-starter/kazoo/exceptions.py new file mode 100644 index 00000000..6f32b4f3 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/exceptions.py @@ -0,0 +1,209 @@ +"""Kazoo Exceptions""" +from collections import defaultdict + + +class KazooException(Exception): + """Base Kazoo exception that all other kazoo library exceptions + inherit from""" + + +class ZookeeperError(KazooException): + """Base Zookeeper exception for errors originating from the + Zookeeper server""" + + +class CancelledError(KazooException): + """Raised when a process is cancelled by another thread""" + + +class ConfigurationError(KazooException): + """Raised if the configuration arguments to an object are + invalid""" + + +class ZookeeperStoppedError(KazooException): + """Raised when the kazoo client stopped (and thus not connected)""" + + +class ConnectionDropped(KazooException): + """Internal error for jumping out of loops""" + + +class LockTimeout(KazooException): + """Raised if failed to acquire a lock. + + .. versionadded:: 1.1 + """ + + +class WriterNotClosedException(KazooException): + """Raised if the writer is unable to stop closing when requested. + + .. 
versionadded:: 1.2 + """ + + +def _invalid_error_code(): + raise RuntimeError('Invalid error code') + + +EXCEPTIONS = defaultdict(_invalid_error_code) + + +def _zookeeper_exception(code): + def decorator(klass): + def create(*args, **kwargs): + return klass(args, kwargs) + + EXCEPTIONS[code] = create + klass.code = code + return klass + + return decorator + + +@_zookeeper_exception(0) +class RolledBackError(ZookeeperError): + pass + + +@_zookeeper_exception(-1) +class SystemZookeeperError(ZookeeperError): + pass + + +@_zookeeper_exception(-2) +class RuntimeInconsistency(ZookeeperError): + pass + + +@_zookeeper_exception(-3) +class DataInconsistency(ZookeeperError): + pass + + +@_zookeeper_exception(-4) +class ConnectionLoss(ZookeeperError): + pass + + +@_zookeeper_exception(-5) +class MarshallingError(ZookeeperError): + pass + + +@_zookeeper_exception(-6) +class UnimplementedError(ZookeeperError): + pass + + +@_zookeeper_exception(-7) +class OperationTimeoutError(ZookeeperError): + pass + + +@_zookeeper_exception(-8) +class BadArgumentsError(ZookeeperError): + pass + + +@_zookeeper_exception(-13) +class NewConfigNoQuorumError(ZookeeperError): + pass + + +@_zookeeper_exception(-14) +class ReconfigInProcessError(ZookeeperError): + pass + + +@_zookeeper_exception(-100) +class APIError(ZookeeperError): + pass + + +@_zookeeper_exception(-101) +class NoNodeError(ZookeeperError): + pass + + +@_zookeeper_exception(-102) +class NoAuthError(ZookeeperError): + pass + + +@_zookeeper_exception(-103) +class BadVersionError(ZookeeperError): + pass + + +@_zookeeper_exception(-108) +class NoChildrenForEphemeralsError(ZookeeperError): + pass + + +@_zookeeper_exception(-110) +class NodeExistsError(ZookeeperError): + pass + + +@_zookeeper_exception(-111) +class NotEmptyError(ZookeeperError): + pass + + +@_zookeeper_exception(-112) +class SessionExpiredError(ZookeeperError): + pass + + +@_zookeeper_exception(-113) +class InvalidCallbackError(ZookeeperError): + pass + + +@_zookeeper_exception(-114) +class InvalidACLError(ZookeeperError): + pass + + +@_zookeeper_exception(-115) +class AuthFailedError(ZookeeperError): + pass + + +@_zookeeper_exception(-118) +class SessionMovedError(ZookeeperError): + pass + + +@_zookeeper_exception(-119) +class NotReadOnlyCallError(ZookeeperError): + """An API call that is not read-only was used while connected to + a read-only server""" + + +class ConnectionClosedError(SessionExpiredError): + """Connection is closed""" + + +# BW Compat aliases for C lib style exceptions +ConnectionLossException = ConnectionLoss +MarshallingErrorException = MarshallingError +SystemErrorException = SystemZookeeperError +RuntimeInconsistencyException = RuntimeInconsistency +DataInconsistencyException = DataInconsistency +UnimplementedException = UnimplementedError +OperationTimeoutException = OperationTimeoutError +BadArgumentsException = BadArgumentsError +ApiErrorException = APIError +NoNodeException = NoNodeError +NoAuthException = NoAuthError +BadVersionException = BadVersionError +NoChildrenForEphemeralsException = NoChildrenForEphemeralsError +NodeExistsException = NodeExistsError +InvalidACLException = InvalidACLError +AuthFailedException = AuthFailedError +NotEmptyException = NotEmptyError +SessionExpiredException = SessionExpiredError +InvalidCallbackException = InvalidCallbackError diff --git a/yarn/src/main/python/task-starter/kazoo/handlers/__init__.py b/yarn/src/main/python/task-starter/kazoo/handlers/__init__.py new file mode 100644 index 00000000..792d6005 --- /dev/null +++ 
b/yarn/src/main/python/task-starter/kazoo/handlers/__init__.py
@@ -0,0 +1 @@
+#
diff --git a/yarn/src/main/python/task-starter/kazoo/handlers/eventlet.py b/yarn/src/main/python/task-starter/kazoo/handlers/eventlet.py
new file mode 100644
index 00000000..dff42f8f
--- /dev/null
+++ b/yarn/src/main/python/task-starter/kazoo/handlers/eventlet.py
@@ -0,0 +1,173 @@
+"""An eventlet based handler."""
+from __future__ import absolute_import
+
+import contextlib
+import logging
+
+import kazoo.python2atexit as python2atexit
+
+import eventlet
+from eventlet.green import select as green_select
+from eventlet.green import socket as green_socket
+from eventlet.green import time as green_time
+from eventlet.green import threading as green_threading
+from eventlet import queue as green_queue
+
+from kazoo.handlers import utils
+
+LOG = logging.getLogger(__name__)
+
+# sentinel objects
+_STOP = object()
+
+
+@contextlib.contextmanager
+def _yield_before_after():
+    # Yield to any other co-routines...
+    #
+    # See: http://eventlet.net/doc/modules/greenthread.html
+    # for how this zero sleep is really a cooperative yield to other potential
+    # co-routines...
+    eventlet.sleep(0)
+    try:
+        yield
+    finally:
+        eventlet.sleep(0)
+
+
+class TimeoutError(Exception):
+    pass
+
+
+class AsyncResult(utils.AsyncResult):
+    """A one-time event that stores a value or an exception"""
+    def __init__(self, handler):
+        super(AsyncResult, self).__init__(handler,
+                                          green_threading.Condition,
+                                          TimeoutError)
+
+
+class SequentialEventletHandler(object):
+    """Eventlet handler for sequentially executing callbacks.
+
+    This handler executes callbacks in a sequential manner. A queue is
+    created for each of the callback events, so that each type of event
+    has its callback type run sequentially. These are split into two
+    queues, one for watch events and one for async result completion
+    callbacks.
+
+    Each queue type has a greenthread worker that pulls the callback event
+    off the queue and runs it in the order the client sees it.
+
+    This split helps ensure that watch callbacks won't block session
+    re-establishment should the connection be lost during a Zookeeper
+    client call.
+
+    Watch and completion callbacks should avoid blocking behavior as
+    the next callback of that type won't be run until it completes. If
+    you need to block, spawn a new greenthread and return immediately so
+    callbacks can proceed.
+
+    .. note::
+
+        Completion callbacks can block to wait on Zookeeper calls, but
+        no other completion callbacks will execute until the callback
+        returns.
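+
+    A minimal usage sketch (assuming an eventlet application and a
+    ZooKeeper server reachable at the illustrative address below):
+
+    .. code-block:: python
+
+        from kazoo.client import KazooClient
+        from kazoo.handlers.eventlet import SequentialEventletHandler
+
+        client = KazooClient(hosts='127.0.0.1:2181',
+                             handler=SequentialEventletHandler())
+        client.start()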
+ + """ + name = "sequential_eventlet_handler" + + def __init__(self): + """Create a :class:`SequentialEventletHandler` instance""" + self.callback_queue = green_queue.LightQueue() + self.completion_queue = green_queue.LightQueue() + self._workers = [] + self._started = False + + @staticmethod + def sleep_func(wait): + green_time.sleep(wait) + + @property + def running(self): + return self._started + + timeout_exception = TimeoutError + + def _process_completion_queue(self): + while True: + cb = self.completion_queue.get() + if cb is _STOP: + break + try: + with _yield_before_after(): + cb() + except Exception: + LOG.warning("Exception in worker completion queue greenlet", + exc_info=True) + + def _process_callback_queue(self): + while True: + cb = self.callback_queue.get() + if cb is _STOP: + break + try: + with _yield_before_after(): + cb() + except Exception: + LOG.warning("Exception in worker callback queue greenlet", + exc_info=True) + + def start(self): + if not self._started: + # Spawn our worker threads, we have + # - A callback worker for watch events to be called + # - A completion worker for completion events to be called + w = eventlet.spawn(self._process_completion_queue) + self._workers.append((w, self.completion_queue)) + w = eventlet.spawn(self._process_callback_queue) + self._workers.append((w, self.callback_queue)) + self._started = True + python2atexit.register(self.stop) + + def stop(self): + while self._workers: + w, q = self._workers.pop() + q.put(_STOP) + w.wait() + self._started = False + python2atexit.unregister(self.stop) + + def socket(self, *args, **kwargs): + return utils.create_tcp_socket(green_socket) + + def create_socket_pair(self): + return utils.create_socket_pair(green_socket) + + def event_object(self): + return green_threading.Event() + + def lock_object(self): + return green_threading.Lock() + + def rlock_object(self): + return green_threading.RLock() + + def create_connection(self, *args, **kwargs): + return utils.create_tcp_connection(green_socket, *args, **kwargs) + + def select(self, *args, **kwargs): + with _yield_before_after(): + return green_select.select(*args, **kwargs) + + def async_result(self): + return AsyncResult(self) + + def spawn(self, func, *args, **kwargs): + t = green_threading.Thread(target=func, args=args, kwargs=kwargs) + t.daemon = True + t.start() + return t + + def dispatch_callback(self, callback): + self.callback_queue.put(lambda: callback.func(*callback.args)) diff --git a/yarn/src/main/python/task-starter/kazoo/handlers/gevent.py b/yarn/src/main/python/task-starter/kazoo/handlers/gevent.py new file mode 100644 index 00000000..be8b8e96 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/handlers/gevent.py @@ -0,0 +1,163 @@ +"""A gevent based handler.""" +from __future__ import absolute_import + +import logging + +import gevent +import gevent.event +import gevent.queue +import gevent.select +import gevent.thread + +from gevent.queue import Empty +from gevent.queue import Queue +from gevent import socket +try: + from gevent.lock import Semaphore, RLock +except ImportError: + from gevent.coros import Semaphore, RLock + +from kazoo.handlers import utils +from kazoo import python2atexit + +_using_libevent = gevent.__version__.startswith('0.') + +log = logging.getLogger(__name__) + +_STOP = object() + +AsyncResult = gevent.event.AsyncResult + + +class SequentialGeventHandler(object): + """Gevent handler for sequentially executing callbacks. + + This handler executes callbacks in a sequential manner. 
A queue is + created for each of the callback events, so that each type of event + has its callback type run sequentially. + + Each queue type has a greenlet worker that pulls the callback event + off the queue and runs it in the order the client sees it. + + This split helps ensure that watch callbacks won't block session + re-establishment should the connection be lost during a Zookeeper + client call. + + Watch callbacks should avoid blocking behavior as the next callback + of that type won't be run until it completes. If you need to block, + spawn a new greenlet and return immediately so callbacks can + proceed. + + """ + name = "sequential_gevent_handler" + sleep_func = staticmethod(gevent.sleep) + + def __init__(self): + """Create a :class:`SequentialGeventHandler` instance""" + self.callback_queue = Queue() + self._running = False + self._async = None + self._state_change = Semaphore() + self._workers = [] + + class timeout_exception(gevent.event.Timeout): + def __init__(self, msg): + gevent.event.Timeout.__init__(self, exception=msg) + + def _create_greenlet_worker(self, queue): + def greenlet_worker(): + while True: + try: + func = queue.get() + if func is _STOP: + break + func() + except Empty: + continue + except Exception as exc: + log.warning("Exception in worker greenlet") + log.exception(exc) + return gevent.spawn(greenlet_worker) + + def start(self): + """Start the greenlet workers.""" + with self._state_change: + if self._running: + return + + self._running = True + + # Spawn our worker greenlets, we have + # - A callback worker for watch events to be called + for queue in (self.callback_queue,): + w = self._create_greenlet_worker(queue) + self._workers.append(w) + python2atexit.register(self.stop) + + def stop(self): + """Stop the greenlet workers and empty all queues.""" + with self._state_change: + if not self._running: + return + + self._running = False + + for queue in (self.callback_queue,): + queue.put(_STOP) + + while self._workers: + worker = self._workers.pop() + worker.join() + + # Clear the queues + self.callback_queue = Queue() # pragma: nocover + + python2atexit.unregister(self.stop) + + def select(self, *args, **kwargs): + return gevent.select.select(*args, **kwargs) + + def socket(self, *args, **kwargs): + return utils.create_tcp_socket(socket) + + def create_connection(self, *args, **kwargs): + return utils.create_tcp_connection(socket, *args, **kwargs) + + def create_socket_pair(self): + return utils.create_socket_pair(socket) + + def event_object(self): + """Create an appropriate Event object""" + return gevent.event.Event() + + def lock_object(self): + """Create an appropriate Lock object""" + return gevent.thread.allocate_lock() + + def rlock_object(self): + """Create an appropriate RLock object""" + return RLock() + + def async_result(self): + """Create a :class:`AsyncResult` instance + + The :class:`AsyncResult` instance will have its completion + callbacks executed in the thread the + :class:`SequentialGeventHandler` is created in (which should be + the gevent/main thread). + + """ + return AsyncResult() + + def spawn(self, func, *args, **kwargs): + """Spawn a function to run asynchronously""" + return gevent.spawn(func, *args, **kwargs) + + def dispatch_callback(self, callback): + """Dispatch to the callback object + + The callback is put on separate queues to run depending on the + type as documented for the :class:`SequentialGeventHandler`. 
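+
+        A sketch of the dispatch path
+        (:class:`~kazoo.protocol.states.Callback` is a namedtuple of
+        ``(type, func, args)``; the names below are illustrative):
+
+        .. code-block:: python
+
+            from kazoo.protocol.states import Callback
+
+            cb = Callback('watch', on_event, (watched_event,))
+            handler.dispatch_callback(cb)  # queues on_event(watched_event)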
+ + """ + self.callback_queue.put(lambda: callback.func(*callback.args)) diff --git a/yarn/src/main/python/task-starter/kazoo/handlers/threading.py b/yarn/src/main/python/task-starter/kazoo/handlers/threading.py new file mode 100644 index 00000000..1eaac649 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/handlers/threading.py @@ -0,0 +1,196 @@ +"""A threading based handler. + +The :class:`SequentialThreadingHandler` is intended for regular Python +environments that use threads. + +.. warning:: + + Do not use :class:`SequentialThreadingHandler` with applications + using asynchronous event loops (like gevent). Use the + :class:`~kazoo.handlers.gevent.SequentialGeventHandler` instead. + +""" +from __future__ import absolute_import + +import errno +import logging +import select +import socket +import threading +import time + +import kazoo.python2atexit as python2atexit + +try: + import Queue +except ImportError: # pragma: nocover + import queue as Queue + +from kazoo.handlers import utils + +# sentinel objects +_STOP = object() + +log = logging.getLogger(__name__) + + +class KazooTimeoutError(Exception): + pass + + +class AsyncResult(utils.AsyncResult): + """A one-time event that stores a value or an exception""" + def __init__(self, handler): + super(AsyncResult, self).__init__(handler, + threading.Condition, + KazooTimeoutError) + + +class SequentialThreadingHandler(object): + """Threading handler for sequentially executing callbacks. + + This handler executes callbacks in a sequential manner. A queue is + created for each of the callback events, so that each type of event + has its callback type run sequentially. These are split into two + queues, one for watch events and one for async result completion + callbacks. + + Each queue type has a thread worker that pulls the callback event + off the queue and runs it in the order the client sees it. + + This split helps ensure that watch callbacks won't block session + re-establishment should the connection be lost during a Zookeeper + client call. + + Watch and completion callbacks should avoid blocking behavior as + the next callback of that type won't be run until it completes. If + you need to block, spawn a new thread and return immediately so + callbacks can proceed. + + .. note:: + + Completion callbacks can block to wait on Zookeeper calls, but + no other completion callbacks will execute until the callback + returns. 
+ + """ + name = "sequential_threading_handler" + timeout_exception = KazooTimeoutError + sleep_func = staticmethod(time.sleep) + queue_impl = Queue.Queue + queue_empty = Queue.Empty + + def __init__(self): + """Create a :class:`SequentialThreadingHandler` instance""" + self.callback_queue = self.queue_impl() + self.completion_queue = self.queue_impl() + self._running = False + self._state_change = threading.Lock() + self._workers = [] + + def _create_thread_worker(self, queue): + def _thread_worker(): # pragma: nocover + while True: + try: + func = queue.get() + try: + if func is _STOP: + break + func() + except Exception: + log.exception("Exception in worker queue thread") + finally: + queue.task_done() + except self.queue_empty: + continue + t = self.spawn(_thread_worker) + return t + + def start(self): + """Start the worker threads.""" + with self._state_change: + if self._running: + return + + # Spawn our worker threads, we have + # - A callback worker for watch events to be called + # - A completion worker for completion events to be called + for queue in (self.completion_queue, self.callback_queue): + w = self._create_thread_worker(queue) + self._workers.append(w) + self._running = True + python2atexit.register(self.stop) + + def stop(self): + """Stop the worker threads and empty all queues.""" + with self._state_change: + if not self._running: + return + + self._running = False + + for queue in (self.completion_queue, self.callback_queue): + queue.put(_STOP) + + self._workers.reverse() + while self._workers: + worker = self._workers.pop() + worker.join() + + # Clear the queues + self.callback_queue = self.queue_impl() + self.completion_queue = self.queue_impl() + python2atexit.unregister(self.stop) + + def select(self, *args, **kwargs): + try: + return select.select(*args, **kwargs) + except select.error as ex: + # if the system call was interrupted, we'll return as a timeout + # in Python 3, system call interruptions are a native exception + # in Python 2, they are not + errnum = ex.errno if isinstance(ex, OSError) else ex[0] + # to mimic a timeout, we return the same thing select would + if errnum == errno.EINTR: + return ([], [], []) + raise + + def socket(self): + return utils.create_tcp_socket(socket) + + def create_connection(self, *args, **kwargs): + return utils.create_tcp_connection(socket, *args, **kwargs) + + def create_socket_pair(self): + return utils.create_socket_pair(socket) + + def event_object(self): + """Create an appropriate Event object""" + return threading.Event() + + def lock_object(self): + """Create a lock object""" + return threading.Lock() + + def rlock_object(self): + """Create an appropriate RLock object""" + return threading.RLock() + + def async_result(self): + """Create a :class:`AsyncResult` instance""" + return AsyncResult(self) + + def spawn(self, func, *args, **kwargs): + t = threading.Thread(target=func, args=args, kwargs=kwargs) + t.daemon = True + t.start() + return t + + def dispatch_callback(self, callback): + """Dispatch to the callback object + + The callback is put on separate queues to run depending on the + type as documented for the :class:`SequentialThreadingHandler`. 
+ + """ + self.callback_queue.put(lambda: callback.func(*callback.args)) diff --git a/yarn/src/main/python/task-starter/kazoo/handlers/utils.py b/yarn/src/main/python/task-starter/kazoo/handlers/utils.py new file mode 100644 index 00000000..6270be44 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/handlers/utils.py @@ -0,0 +1,229 @@ +"""Kazoo handler helpers""" + +import errno +import functools +import select + +HAS_FNCTL = True +try: + import fcntl +except ImportError: # pragma: nocover + HAS_FNCTL = False + +# sentinel objects +_NONE = object() + + +class AsyncResult(object): + """A one-time event that stores a value or an exception""" + def __init__(self, handler, condition_factory, timeout_factory): + self._handler = handler + self._exception = _NONE + self._condition = condition_factory() + self._callbacks = [] + self._timeout_factory = timeout_factory + self.value = None + + def ready(self): + """Return true if and only if it holds a value or an + exception""" + return self._exception is not _NONE + + def successful(self): + """Return true if and only if it is ready and holds a value""" + return self._exception is None + + @property + def exception(self): + if self._exception is not _NONE: + return self._exception + + def set(self, value=None): + """Store the value. Wake up the waiters.""" + with self._condition: + self.value = value + self._exception = None + for callback in self._callbacks: + self._handler.completion_queue.put( + lambda: callback(self) + ) + self._condition.notify_all() + + def set_exception(self, exception): + """Store the exception. Wake up the waiters.""" + with self._condition: + self._exception = exception + for callback in self._callbacks: + self._handler.completion_queue.put( + lambda: callback(self) + ) + self._condition.notify_all() + + def get(self, block=True, timeout=None): + """Return the stored value or raise the exception. + + If there is no value raises TimeoutError. + + """ + with self._condition: + if self._exception is not _NONE: + if self._exception is None: + return self.value + raise self._exception + elif block: + self._condition.wait(timeout) + if self._exception is not _NONE: + if self._exception is None: + return self.value + raise self._exception + + # if we get to this point we timeout + raise self._timeout_factory() + + def get_nowait(self): + """Return the value or raise the exception without blocking. + + If nothing is available, raises TimeoutError + + """ + return self.get(block=False) + + def wait(self, timeout=None): + """Block until the instance is ready.""" + with self._condition: + self._condition.wait(timeout) + return self._exception is not _NONE + + def rawlink(self, callback): + """Register a callback to call when a value or an exception is + set""" + with self._condition: + # Are we already set? 
Dispatch it now + if self.ready(): + self._handler.completion_queue.put( + lambda: callback(self) + ) + return + + if callback not in self._callbacks: + self._callbacks.append(callback) + + def unlink(self, callback): + """Remove the callback set by :meth:`rawlink`""" + with self._condition: + if self.ready(): + # Already triggered, ignore + return + + if callback in self._callbacks: + self._callbacks.remove(callback) + + +def _set_fd_cloexec(fd): + flags = fcntl.fcntl(fd, fcntl.F_GETFD) + fcntl.fcntl(fd, fcntl.F_SETFD, flags | fcntl.FD_CLOEXEC) + + +def _set_default_tcpsock_options(module, sock): + sock.setsockopt(module.IPPROTO_TCP, module.TCP_NODELAY, 1) + if HAS_FNCTL: + _set_fd_cloexec(sock) + return sock + + +def create_socket_pair(module, port=0): + """Create socket pair. + + If socket.socketpair isn't available, we emulate it. + """ + # See if socketpair() is available. + have_socketpair = hasattr(module, 'socketpair') + if have_socketpair: + client_sock, srv_sock = module.socketpair() + return client_sock, srv_sock + + # Create a non-blocking temporary server socket + temp_srv_sock = module.socket() + temp_srv_sock.setblocking(False) + temp_srv_sock.bind(('', port)) + port = temp_srv_sock.getsockname()[1] + temp_srv_sock.listen(1) + + # Create non-blocking client socket + client_sock = module.socket() + client_sock.setblocking(False) + try: + client_sock.connect(('localhost', port)) + except module.error as err: + # EWOULDBLOCK is not an error, as the socket is non-blocking + if err.errno != errno.EWOULDBLOCK: + raise + + # Use select to wait for connect() to succeed. + timeout = 1 + readable = select.select([temp_srv_sock], [], [], timeout)[0] + if temp_srv_sock not in readable: + raise Exception('Client socket not connected in %s' + ' second(s)' % (timeout)) + srv_sock, _ = temp_srv_sock.accept() + return client_sock, srv_sock + + +def create_tcp_socket(module): + """Create a TCP socket with the CLOEXEC flag set. + """ + type_ = module.SOCK_STREAM + if hasattr(module, 'SOCK_CLOEXEC'): # pragma: nocover + # if available, set cloexec flag during socket creation + type_ |= module.SOCK_CLOEXEC + sock = module.socket(module.AF_INET, type_) + _set_default_tcpsock_options(module, sock) + return sock + + +def create_tcp_connection(module, address, timeout=None): + if timeout is None: + # thanks to create_connection() developers for + # this ugliness... + timeout = module._GLOBAL_DEFAULT_TIMEOUT + + sock = module.create_connection(address, timeout) + _set_default_tcpsock_options(module, sock) + return sock + + +def capture_exceptions(async_result): + """Return a new decorated function that propagates the exceptions of the + wrapped function to an async_result. + + :param async_result: An async result implementing :class:`IAsyncResult` + + """ + def capture(function): + @functools.wraps(function) + def captured_function(*args, **kwargs): + try: + return function(*args, **kwargs) + except Exception as exc: + async_result.set_exception(exc) + return captured_function + return capture + + +def wrap(async_result): + """Return a new decorated function that propagates the return value or + exception of wrapped function to an async_result. NOTE: Only propagates a + non-None return value. 
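+
+    A usage sketch (``compute_value`` is illustrative):
+
+    .. code-block:: python
+
+        result = handler.async_result()
+
+        @wrap(result)
+        def fetch():
+            return compute_value()  # non-None value lands in result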
+ + :param async_result: An async result implementing :class:`IAsyncResult` + + """ + def capture(function): + @capture_exceptions(async_result) + def captured_function(*args, **kwargs): + value = function(*args, **kwargs) + if value is not None: + async_result.set(value) + return value + return captured_function + return capture diff --git a/yarn/src/main/python/task-starter/kazoo/hosts.py b/yarn/src/main/python/task-starter/kazoo/hosts.py new file mode 100644 index 00000000..26b3573f --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/hosts.py @@ -0,0 +1,25 @@ +import random + +from six.moves import urllib_parse + + +def collect_hosts(hosts, randomize=True): + """Collect a set of hosts and an optional chroot from a string.""" + host_ports, chroot = hosts.partition("/")[::2] + chroot = "/" + chroot if chroot else None + + result = [] + for host_port in host_ports.split(","): + # put all complexity of dealing with + # IPv4 & IPv6 address:port on the urlsplit + res = urllib_parse.urlsplit("xxx://" + host_port) + host = res.hostname + if host is None: + raise ValueError("bad hostname") + port = int(res.port) if res.port else 2181 + result.append((host.strip(), port)) + + if randomize: + random.shuffle(result) + + return result, chroot diff --git a/yarn/src/main/python/task-starter/kazoo/interfaces.py b/yarn/src/main/python/task-starter/kazoo/interfaces.py new file mode 100644 index 00000000..351f1fd8 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/interfaces.py @@ -0,0 +1,203 @@ +"""Kazoo Interfaces + +.. versionchanged:: 1.4 + + The classes in this module used to be interface declarations based on + `zope.interface.Interface`. They were converted to normal classes and + now serve as documentation only. + +""" + +# public API + + +class IHandler(object): + """A Callback Handler for Zookeeper completion and watch callbacks. + + This object must implement several methods responsible for + determining how completion / watch callbacks are handled as well as + the method for calling :class:`IAsyncResult` callback functions. + + These functions are used to abstract differences between a Python + threading environment and asynchronous single-threaded environments + like gevent. The minimum functionality needed for Kazoo to handle + these differences is encompassed in this interface. + + The Handler should document how callbacks are called for: + + * Zookeeper completion events + * Zookeeper watch events + + .. attribute:: name + + Human readable name of the Handler interface. + + .. attribute:: timeout_exception + + Exception class that should be thrown and captured if a + result is not available within the given time. + + .. attribute:: sleep_func + + Appropriate sleep function that can be called with a single + argument and sleep. + + """ + + def start(self): + """Start the handler, used for setting up the handler.""" + + def stop(self): + """Stop the handler. 
Should block until the handler is safely + stopped.""" + + def select(self): + """A select method that implements Python's select.select + API""" + + def socket(self): + """A socket method that implements Python's socket.socket + API""" + + def create_connection(self): + """A socket method that implements Python's + socket.create_connection API""" + + def event_object(self): + """Return an appropriate object that implements Python's + threading.Event API""" + + def lock_object(self): + """Return an appropriate object that implements Python's + threading.Lock API""" + + def rlock_object(self): + """Return an appropriate object that implements Python's + threading.RLock API""" + + def async_result(self): + """Return an instance that conforms to the + :class:`~IAsyncResult` interface appropriate for this + handler""" + + def spawn(self, func, *args, **kwargs): + """Spawn a function to run asynchronously + + :param args: args to call the function with. + :param kwargs: keyword args to call the function with. + + This method should return immediately and execute the function + with the provided args and kwargs in an asynchronous manner. + + """ + + def dispatch_callback(self, callback): + """Dispatch to the callback object + + :param callback: A :class:`~kazoo.protocol.states.Callback` + object to be called. + + """ + + +class IAsyncResult(object): + """An Async Result object that can be queried for a value that has + been set asynchronously. + + This object is modeled on the ``gevent`` AsyncResult object. + + The implementation must account for the fact that the :meth:`set` + and :meth:`set_exception` methods will be called from within the + Zookeeper thread which may require extra care under asynchronous + environments. + + .. attribute:: value + + Holds the value passed to :meth:`set` if :meth:`set` was + called. Otherwise `None`. + + .. attribute:: exception + + Holds the exception instance passed to :meth:`set_exception` + if :meth:`set_exception` was called. Otherwise `None`. + + """ + + def ready(self): + """Return `True` if and only if it holds a value or an + exception""" + + def successful(self): + """Return `True` if and only if it is ready and holds a + value""" + + def set(self, value=None): + """Store the value. Wake up the waiters. + + :param value: Value to store as the result. + + Any waiters blocking on :meth:`get` or :meth:`wait` are woken + up. Sequential calls to :meth:`wait` and :meth:`get` will not + block at all.""" + + def set_exception(self, exception): + """Store the exception. Wake up the waiters. + + :param exception: Exception to raise when fetching the value. + + Any waiters blocking on :meth:`get` or :meth:`wait` are woken + up. Sequential calls to :meth:`wait` and :meth:`get` will not + block at all.""" + + def get(self, block=True, timeout=None): + """Return the stored value or raise the exception + + :param block: Whether this method should block or return + immediately. + :type block: bool + :param timeout: How long to wait for a value when `block` is + `True`. + :type timeout: float + + If this instance already holds a value / an exception, return / + raise it immediately. Otherwise, block until :meth:`set` or + :meth:`set_exception` has been called or until the optional + timeout occurs.""" + + def get_nowait(self): + """Return the value or raise the exception without blocking. 
+ + If nothing is available, raise the Timeout exception class on + the associated :class:`IHandler` interface.""" + + def wait(self, timeout=None): + """Block until the instance is ready. + + :param timeout: How long to wait for a value when `block` is + `True`. + :type timeout: float + + If this instance already holds a value / an exception, return / + raise it immediately. Otherwise, block until :meth:`set` or + :meth:`set_exception` has been called or until the optional + timeout occurs.""" + + def rawlink(self, callback): + """Register a callback to call when a value or an exception is + set + + :param callback: + A callback function to call after :meth:`set` or + :meth:`set_exception` has been called. This function will + be passed a single argument, this instance. + :type callback: func + + """ + + def unlink(self, callback): + """Remove the callback set by :meth:`rawlink` + + :param callback: A callback function to remove. + :type callback: func + + """ diff --git a/yarn/src/main/python/task-starter/kazoo/loggingsupport.py b/yarn/src/main/python/task-starter/kazoo/loggingsupport.py new file mode 100644 index 00000000..bc12c7c8 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/loggingsupport.py @@ -0,0 +1 @@ +BLATHER = 5 # log level for low-level debugging diff --git a/yarn/src/main/python/task-starter/kazoo/protocol/__init__.py b/yarn/src/main/python/task-starter/kazoo/protocol/__init__.py new file mode 100644 index 00000000..792d6005 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/protocol/__init__.py @@ -0,0 +1 @@ +# diff --git a/yarn/src/main/python/task-starter/kazoo/protocol/connection.py b/yarn/src/main/python/task-starter/kazoo/protocol/connection.py new file mode 100644 index 00000000..067a6295 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/protocol/connection.py @@ -0,0 +1,630 @@ +"""Zookeeper Protocol Connection Handler""" +import logging +import random +import select +import socket + +import sys +import time +from binascii import hexlify +from contextlib import contextmanager + +from kazoo.exceptions import ( + AuthFailedError, + ConnectionDropped, + EXCEPTIONS, + SessionExpiredError, + NoNodeError +) +from kazoo.loggingsupport import BLATHER +from kazoo.protocol.serialization import ( + Auth, + Close, + Connect, + Exists, + GetChildren, + Ping, + PingInstance, + ReplyHeader, + Transaction, + Watch, + int_struct +) +from kazoo.protocol.states import ( + Callback, + KeeperState, + WatchedEvent, + EVENT_TYPE_MAP, +) +from kazoo.retry import ( + ForceRetryError, + RetryFailedError +) + +log = logging.getLogger(__name__) + + +# Special testing hook objects used to force a session expired error as +# if it came from the server +_SESSION_EXPIRED = object() +_CONNECTION_DROP = object() + +STOP_CONNECTING = object() + +CREATED_EVENT = 1 +DELETED_EVENT = 2 +CHANGED_EVENT = 3 +CHILD_EVENT = 4 + +WATCH_XID = -1 +PING_XID = -2 +AUTH_XID = -4 + +CLOSE_RESPONSE = Close.type + +if sys.version_info > (3, ): # pragma: nocover + def buffer(obj, offset=0): + return memoryview(obj)[offset:] + + advance_iterator = next +else: # pragma: nocover + def advance_iterator(it): + return it.next() + + +class RWPinger(object): + """A Read/Write Server Pinger Iterable + + This object is initialized with the hosts iterator object and the + socket creation function. Anytime `next` is called on its iterator + it yields either False, or a host, port tuple if it found a r/w + capable Zookeeper node. 
+
+    After the first run-through of hosts, an exponential back-off delay
+    is added before the next run. This delay is tracked internally and
+    the iterator will yield False if called too soon.
+
+    """
+    def __init__(self, hosts, connection_func, socket_handling):
+        self.hosts = hosts
+        self.connection = connection_func
+        self.last_attempt = None
+        self.socket_handling = socket_handling
+
+    def __iter__(self):
+        if not self.last_attempt:
+            self.last_attempt = time.time()
+        delay = 0.5
+        while True:
+            yield self._next_server(delay)
+
+    def _next_server(self, delay):
+        jitter = random.randint(0, 100) / 100.0
+        while time.time() < self.last_attempt + delay + jitter:
+            # Skip rw ping checks if it's too soon
+            return False
+        for host, port in self.hosts:
+            log.debug("Pinging server for r/w: %s:%s", host, port)
+            self.last_attempt = time.time()
+            try:
+                with self.socket_handling():
+                    sock = self.connection((host, port))
+                    sock.sendall(b"isro")
+                    result = sock.recv(8192)
+                    sock.close()
+                    if result == b'rw':
+                        return (host, port)
+                    else:
+                        return False
+            except ConnectionDropped:
+                return False
+
+            # Add some jitter between host pings
+            while time.time() < self.last_attempt + jitter:
+                return False
+        delay *= 2
+
+
+class RWServerAvailable(Exception):
+    """Thrown if a RW Server becomes available"""
+
+
+class ConnectionHandler(object):
+    """Zookeeper connection handler"""
+    def __init__(self, client, retry_sleeper, logger=None):
+        self.client = client
+        self.handler = client.handler
+        self.retry_sleeper = retry_sleeper
+        self.logger = logger or log
+
+        # Our event objects
+        self.connection_closed = client.handler.event_object()
+        self.connection_closed.set()
+        self.connection_stopped = client.handler.event_object()
+        self.connection_stopped.set()
+        self.ping_outstanding = client.handler.event_object()
+
+        self._read_sock = None
+        self._write_sock = None
+
+        self._socket = None
+        self._xid = None
+        self._rw_server = None
+        self._ro_mode = False
+
+        self._connection_routine = None
+
+    # This is instance specific to avoid odd thread bug issues in Python
+    # during shutdown global cleanup
+    @contextmanager
+    def _socket_error_handling(self):
+        try:
+            yield
+        except (socket.error, select.error) as e:
+            err = getattr(e, 'strerror', e)
+            raise ConnectionDropped("socket connection error: %s" % (err,))
+
+    def start(self):
+        """Start the connection up"""
+        if self.connection_closed.is_set():
+            rw_sockets = self.handler.create_socket_pair()
+            self._read_sock, self._write_sock = rw_sockets
+            self.connection_closed.clear()
+        if self._connection_routine:
+            raise Exception("Unable to start, connection routine already "
+                            "active.")
+        self._connection_routine = self.handler.spawn(self.zk_loop)
+
+    def stop(self, timeout=None):
+        """Ensure the writer has stopped, wait to see if it does."""
+        self.connection_stopped.wait(timeout)
+        if self._connection_routine:
+            self._connection_routine.join()
+            self._connection_routine = None
+        return self.connection_stopped.is_set()
+
+    def close(self):
+        """Release resources held by the connection
+
+        The connection can be restarted afterwards.
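+
+        Sketch of the expected shutdown order (illustrative; ``conn``
+        is a started :class:`ConnectionHandler`):
+
+        .. code-block:: python
+
+            conn.stop(timeout=5)  # wait for zk_loop to wind down
+            conn.close()          # then release the socket pair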
+ """ + if not self.connection_stopped.is_set(): + raise Exception("Cannot close connection until it is stopped") + self.connection_closed.set() + ws, rs = self._write_sock, self._read_sock + self._write_sock = self._read_sock = None + if ws is not None: + ws.close() + if rs is not None: + rs.close() + + def _server_pinger(self): + """Returns a server pinger iterable, that will ping the next + server in the list, and apply a back-off between attempts.""" + return RWPinger(self.client.hosts, self.handler.create_connection, + self._socket_error_handling) + + def _read_header(self, timeout): + b = self._read(4, timeout) + length = int_struct.unpack(b)[0] + b = self._read(length, timeout) + header, offset = ReplyHeader.deserialize(b, 0) + return header, b, offset + + def _read(self, length, timeout): + msgparts = [] + remaining = length + with self._socket_error_handling(): + while remaining > 0: + s = self.handler.select([self._socket], [], [], timeout)[0] + if not s: # pragma: nocover + # If the read list is empty, we got a timeout. We don't + # have to check wlist and xlist as we don't set any + raise self.handler.timeout_exception("socket time-out" + " during read") + + chunk = self._socket.recv(remaining) + if chunk == b'': + raise ConnectionDropped('socket connection broken') + msgparts.append(chunk) + remaining -= len(chunk) + return b"".join(msgparts) + + def _invoke(self, timeout, request, xid=None): + """A special writer used during connection establishment + only""" + self._submit(request, timeout, xid) + zxid = None + if xid: + header, buffer, offset = self._read_header(timeout) + if header.xid != xid: + raise RuntimeError('xids do not match, expected %r ' + 'received %r', xid, header.xid) + if header.zxid > 0: + zxid = header.zxid + if header.err: + callback_exception = EXCEPTIONS[header.err]() + self.logger.debug( + 'Received error(xid=%s) %r', xid, callback_exception) + raise callback_exception + return zxid + + msg = self._read(4, timeout) + length = int_struct.unpack(msg)[0] + msg = self._read(length, timeout) + + if hasattr(request, 'deserialize'): + try: + obj, _ = request.deserialize(msg, 0) + except Exception: + self.logger.exception( + "Exception raised during deserialization " + "of request: %s", request) + + # raise ConnectionDropped so connect loop will retry + raise ConnectionDropped('invalid server response') + self.logger.log(BLATHER, 'Read response %s', obj) + return obj, zxid + + return zxid + + def _submit(self, request, timeout, xid=None): + """Submit a request object with a timeout value and optional + xid""" + b = bytearray() + if xid: + b.extend(int_struct.pack(xid)) + if request.type: + b.extend(int_struct.pack(request.type)) + b += request.serialize() + self.logger.log( + (BLATHER if isinstance(request, Ping) else logging.DEBUG), + "Sending request(xid=%s): %s", xid, request) + self._write(int_struct.pack(len(b)) + b, timeout) + + def _write(self, msg, timeout): + """Write a raw msg to the socket""" + sent = 0 + msg_length = len(msg) + with self._socket_error_handling(): + while sent < msg_length: + s = self.handler.select([], [self._socket], [], timeout)[1] + if not s: # pragma: nocover + # If the write list is empty, we got a timeout. 
We don't + # have to check rlist and xlist as we don't set any + raise self.handler.timeout_exception("socket time-out" + " during write") + msg_slice = buffer(msg, sent) + bytes_sent = self._socket.send(msg_slice) + if not bytes_sent: + raise ConnectionDropped('socket connection broken') + sent += bytes_sent + + def _read_watch_event(self, buffer, offset): + client = self.client + watch, offset = Watch.deserialize(buffer, offset) + path = watch.path + + self.logger.debug('Received EVENT: %s', watch) + + watchers = [] + + if watch.type in (CREATED_EVENT, CHANGED_EVENT): + watchers.extend(client._data_watchers.pop(path, [])) + elif watch.type == DELETED_EVENT: + watchers.extend(client._data_watchers.pop(path, [])) + watchers.extend(client._child_watchers.pop(path, [])) + elif watch.type == CHILD_EVENT: + watchers.extend(client._child_watchers.pop(path, [])) + else: + self.logger.warn('Received unknown event %r', watch.type) + return + + # Strip the chroot if needed + path = client.unchroot(path) + ev = WatchedEvent(EVENT_TYPE_MAP[watch.type], client._state, path) + + # Last check to ignore watches if we've been stopped + if client._stopped.is_set(): + return + + # Dump the watchers to the watch thread + for watch in watchers: + client.handler.dispatch_callback(Callback('watch', watch, (ev,))) + + def _read_response(self, header, buffer, offset): + client = self.client + request, async_object, xid = client._pending.popleft() + if header.zxid and header.zxid > 0: + client.last_zxid = header.zxid + if header.xid != xid: + raise RuntimeError('xids do not match, expected %r ' + 'received %r', xid, header.xid) + + # Determine if its an exists request and a no node error + exists_error = (header.err == NoNodeError.code and + request.type == Exists.type) + + # Set the exception if its not an exists error + if header.err and not exists_error: + callback_exception = EXCEPTIONS[header.err]() + self.logger.debug( + 'Received error(xid=%s) %r', xid, callback_exception) + if async_object: + async_object.set_exception(callback_exception) + elif request and async_object: + if exists_error: + # It's a NoNodeError, which is fine for an exists + # request + async_object.set(None) + else: + try: + response = request.deserialize(buffer, offset) + except Exception as exc: + self.logger.exception( + "Exception raised during deserialization " + "of request: %s", request) + async_object.set_exception(exc) + return + self.logger.debug( + 'Received response(xid=%s): %r', xid, response) + + # We special case a Transaction as we have to unchroot things + if request.type == Transaction.type: + response = Transaction.unchroot(client, response) + + async_object.set(response) + + # Determine if watchers should be registered + watcher = getattr(request, 'watcher', None) + if not client._stopped.is_set() and watcher: + if isinstance(request, GetChildren): + client._child_watchers[request.path].add(watcher) + else: + client._data_watchers[request.path].add(watcher) + + if isinstance(request, Close): + self.logger.log(BLATHER, 'Read close response') + return CLOSE_RESPONSE + + def _read_socket(self, read_timeout): + """Called when there's something to read on the socket""" + client = self.client + + header, buffer, offset = self._read_header(read_timeout) + if header.xid == PING_XID: + self.logger.log(BLATHER, 'Received Ping') + self.ping_outstanding.clear() + elif header.xid == AUTH_XID: + self.logger.log(BLATHER, 'Received AUTH') + + request, async_object, xid = client._pending.popleft() + if header.err: + 
async_object.set_exception(AuthFailedError()) + client._session_callback(KeeperState.AUTH_FAILED) + else: + async_object.set(True) + elif header.xid == WATCH_XID: + self._read_watch_event(buffer, offset) + else: + self.logger.log(BLATHER, 'Reading for header %r', header) + + return self._read_response(header, buffer, offset) + + def _send_request(self, read_timeout, connect_timeout): + """Called when we have something to send out on the socket""" + client = self.client + try: + request, async_object = client._queue[0] + except IndexError: + # Not actually something on the queue, this can occur if + # something happens to cancel the request such that we + # don't clear the socket below after sending + try: + # Clear possible inconsistence (no request in the queue + # but have data in the read socket), which causes cpu to spin. + self._read_sock.recv(1) + except OSError: + pass + return + + # Special case for testing, if this is a _SessionExpire object + # then throw a SessionExpiration error as if we were dropped + if request is _SESSION_EXPIRED: + raise SessionExpiredError("Session expired: Testing") + if request is _CONNECTION_DROP: + raise ConnectionDropped("Connection dropped: Testing") + + # Special case for auth packets + if request.type == Auth.type: + xid = AUTH_XID + else: + self._xid += 1 + xid = self._xid + + self._submit(request, connect_timeout, xid) + client._queue.popleft() + self._read_sock.recv(1) + client._pending.append((request, async_object, xid)) + + def _send_ping(self, connect_timeout): + self.ping_outstanding.set() + self._submit(PingInstance, connect_timeout, PING_XID) + + # Determine if we need to check for a r/w server + if self._ro_mode: + result = advance_iterator(self._ro_mode) + if result: + self._rw_server = result + raise RWServerAvailable() + + def zk_loop(self): + """Main Zookeeper handling loop""" + self.logger.log(BLATHER, 'ZK loop started') + + self.connection_stopped.clear() + + retry = self.retry_sleeper.copy() + try: + while not self.client._stopped.is_set(): + # If the connect_loop returns STOP_CONNECTING, stop retrying + if retry(self._connect_loop, retry) is STOP_CONNECTING: + break + except RetryFailedError: + self.logger.warning("Failed connecting to Zookeeper " + "within the connection retry policy.") + finally: + self.connection_stopped.set() + self.client._session_callback(KeeperState.CLOSED) + self.logger.log(BLATHER, 'Connection stopped') + + def _connect_loop(self, retry): + # Iterate through the hosts a full cycle before starting over + status = None + for host, port in self.client.hosts: + if self.client._stopped.is_set(): + status = STOP_CONNECTING + break + status = self._connect_attempt(host, port, retry) + if status is STOP_CONNECTING: + break + + if status is STOP_CONNECTING: + return STOP_CONNECTING + else: + raise ForceRetryError('Reconnecting') + + def _connect_attempt(self, host, port, retry): + client = self.client + KazooTimeoutError = self.handler.timeout_exception + close_connection = False + + self._socket = None + + # Were we given a r/w server? 
If so, use that instead + if self._rw_server: + self.logger.log(BLATHER, + "Found r/w server to use, %s:%s", host, port) + host, port = self._rw_server + self._rw_server = None + + if client._state != KeeperState.CONNECTING: + client._session_callback(KeeperState.CONNECTING) + + try: + read_timeout, connect_timeout = self._connect(host, port) + read_timeout = read_timeout / 1000.0 + connect_timeout = connect_timeout / 1000.0 + retry.reset() + self._xid = 0 + self.ping_outstanding.clear() + with self._socket_error_handling(): + while not close_connection: + # Watch for something to read or send + jitter_time = random.randint(0, 40) / 100.0 + # Ensure our timeout is positive + timeout = max([read_timeout / 2.0 - jitter_time, + jitter_time]) + s = self.handler.select([self._socket, self._read_sock], + [], [], timeout)[0] + + if not s: + if self.ping_outstanding.is_set(): + self.ping_outstanding.clear() + raise ConnectionDropped( + "outstanding heartbeat ping not received") + self._send_ping(connect_timeout) + elif s[0] == self._socket: + response = self._read_socket(read_timeout) + close_connection = response == CLOSE_RESPONSE + else: + self._send_request(read_timeout, connect_timeout) + self.logger.info('Closing connection to %s:%s', host, port) + client._session_callback(KeeperState.CLOSED) + return STOP_CONNECTING + except (ConnectionDropped, KazooTimeoutError) as e: + if isinstance(e, ConnectionDropped): + self.logger.warning('Connection dropped: %s', e) + else: + self.logger.warning('Connection time-out: %s', e) + if client._state != KeeperState.CONNECTING: + self.logger.warning("Transition to CONNECTING") + client._session_callback(KeeperState.CONNECTING) + except AuthFailedError: + retry.reset() + self.logger.warning('AUTH_FAILED closing') + client._session_callback(KeeperState.AUTH_FAILED) + return STOP_CONNECTING + except SessionExpiredError: + retry.reset() + self.logger.warning('Session has expired') + client._session_callback(KeeperState.EXPIRED_SESSION) + except RWServerAvailable: + retry.reset() + self.logger.warning('Found a RW server, dropping connection') + client._session_callback(KeeperState.CONNECTING) + except Exception: + self.logger.exception('Unhandled exception in connection loop') + raise + finally: + if self._socket is not None: + self._socket.close() + + def _connect(self, host, port): + client = self.client + self.logger.info('Connecting to %s:%s', host, port) + + self.logger.log(BLATHER, + ' Using session_id: %r session_passwd: %s', + client._session_id, + hexlify(client._session_passwd)) + + with self._socket_error_handling(): + self._socket = self.handler.create_connection( + (host, port), client._session_timeout / 1000.0) + + self._socket.setblocking(0) + + connect = Connect(0, client.last_zxid, client._session_timeout, + client._session_id or 0, client._session_passwd, + client.read_only) + + connect_result, zxid = self._invoke( + client._session_timeout / 1000.0, connect) + + if connect_result.time_out <= 0: + raise SessionExpiredError("Session has expired") + + if zxid: + client.last_zxid = zxid + + # Load return values + client._session_id = connect_result.session_id + client._protocol_version = connect_result.protocol_version + negotiated_session_timeout = connect_result.time_out + connect_timeout = negotiated_session_timeout / len(client.hosts) + read_timeout = negotiated_session_timeout * 2.0 / 3.0 + client._session_passwd = connect_result.passwd + + self.logger.log(BLATHER, + 'Session created, session_id: %r session_passwd: %s\n' + ' negotiated session 
timeout: %s\n' + ' connect timeout: %s\n' + ' read timeout: %s', client._session_id, + hexlify(client._session_passwd), + negotiated_session_timeout, connect_timeout, + read_timeout) + + if connect_result.read_only: + client._session_callback(KeeperState.CONNECTED_RO) + self._ro_mode = iter(self._server_pinger()) + else: + client._session_callback(KeeperState.CONNECTED) + self._ro_mode = None + + for scheme, auth in client.auth_data: + ap = Auth(0, scheme, auth) + zxid = self._invoke(connect_timeout / 1000.0, ap, xid=AUTH_XID) + if zxid: + client.last_zxid = zxid + return read_timeout, connect_timeout diff --git a/yarn/src/main/python/task-starter/kazoo/protocol/paths.py b/yarn/src/main/python/task-starter/kazoo/protocol/paths.py new file mode 100644 index 00000000..7fe961c2 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/protocol/paths.py @@ -0,0 +1,55 @@ +def normpath(path, trailing=False): + """Normalize path, eliminating double slashes, etc.""" + comps = path.split('/') + new_comps = [] + for comp in comps: + if comp == '': + continue + if comp in ('.', '..'): + raise ValueError('relative paths not allowed') + new_comps.append(comp) + new_path = '/'.join(new_comps) + if trailing is True and path.endswith('/'): + new_path += '/' + if path.startswith('/') and new_path != '/': + return '/' + new_path + return new_path + + +def join(a, *p): + """Join two or more pathname components, inserting '/' as needed. + + If any component is an absolute path, all previous path components + will be discarded. + + """ + path = a + for b in p: + if b.startswith('/'): + path = b + elif path == '' or path.endswith('/'): + path += b + else: + path += '/' + b + return path + + +def isabs(s): + """Test whether a path is absolute""" + return s.startswith('/') + + +def basename(p): + """Returns the final component of a pathname""" + i = p.rfind('/') + 1 + return p[i:] + + +def _prefix_root(root, path, trailing=False): + """Prepend a root to a path. 
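+ + For example (illustrative), ``_prefix_root('/app', 'node')`` + returns ``'/app/node'``.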
""" + return normpath(join(_norm_root(root), path.lstrip('/')), + trailing=trailing) + + +def _norm_root(root): + return normpath(join('/', root)) diff --git a/yarn/src/main/python/task-starter/kazoo/protocol/serialization.py b/yarn/src/main/python/task-starter/kazoo/protocol/serialization.py new file mode 100644 index 00000000..88bb9e2e --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/protocol/serialization.py @@ -0,0 +1,415 @@ +"""Zookeeper Serializers, Deserializers, and NamedTuple objects""" +from collections import namedtuple +import struct + +from kazoo.exceptions import EXCEPTIONS +from kazoo.protocol.states import ZnodeStat +from kazoo.security import ACL +from kazoo.security import Id + +# Struct objects with formats compiled +bool_struct = struct.Struct('B') +int_struct = struct.Struct('!i') +long_struct = struct.Struct('!q') +int_int_struct = struct.Struct('!ii') +int_int_long_struct = struct.Struct('!iiq') + +int_long_int_long_struct = struct.Struct('!iqiq') +multiheader_struct = struct.Struct('!iBi') +reply_header_struct = struct.Struct('!iqi') +stat_struct = struct.Struct('!qqqqiiiqiiq') + +try: # pragma: nocover + basestring +except NameError: + basestring = str + + +def read_string(buffer, offset): + """Reads an int specified buffer into a string and returns the + string and the new offset in the buffer""" + length = int_struct.unpack_from(buffer, offset)[0] + offset += int_struct.size + if length < 0: + return None, offset + else: + index = offset + offset += length + return buffer[index:index + length].decode('utf-8'), offset + + +def read_acl(bytes, offset): + perms = int_struct.unpack_from(bytes, offset)[0] + offset += int_struct.size + scheme, offset = read_string(bytes, offset) + id, offset = read_string(bytes, offset) + return ACL(perms, Id(scheme, id)), offset + + +def write_string(bytes): + if not bytes: + return int_struct.pack(-1) + else: + utf8_str = bytes.encode('utf-8') + return int_struct.pack(len(utf8_str)) + utf8_str + + +def write_buffer(bytes): + if bytes is None: + return int_struct.pack(-1) + else: + return int_struct.pack(len(bytes)) + bytes + + +def read_buffer(bytes, offset): + length = int_struct.unpack_from(bytes, offset)[0] + offset += int_struct.size + if length < 0: + return None, offset + else: + index = offset + offset += length + return bytes[index:index + length], offset + + +class Close(namedtuple('Close', '')): + type = -11 + + @classmethod + def serialize(cls): + return b'' + +CloseInstance = Close() + + +class Ping(namedtuple('Ping', '')): + type = 11 + + @classmethod + def serialize(cls): + return b'' + +PingInstance = Ping() + + +class Connect(namedtuple('Connect', 'protocol_version last_zxid_seen' + ' time_out session_id passwd read_only')): + type = None + + def serialize(self): + b = bytearray() + b.extend(int_long_int_long_struct.pack( + self.protocol_version, self.last_zxid_seen, self.time_out, + self.session_id)) + b.extend(write_buffer(self.passwd)) + b.extend([1 if self.read_only else 0]) + return b + + @classmethod + def deserialize(cls, bytes, offset): + proto_version, timeout, session_id = int_int_long_struct.unpack_from( + bytes, offset) + offset += int_int_long_struct.size + password, offset = read_buffer(bytes, offset) + + try: + read_only = bool_struct.unpack_from(bytes, offset)[0] is 1 + offset += bool_struct.size + except struct.error: + read_only = False + return cls(proto_version, 0, timeout, session_id, password, + read_only), offset + + +class Create(namedtuple('Create', 'path data acl flags')): + type = 
1 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(write_buffer(self.data)) + b.extend(int_struct.pack(len(self.acl))) + for acl in self.acl: + b.extend(int_struct.pack(acl.perms) + + write_string(acl.id.scheme) + write_string(acl.id.id)) + b.extend(int_struct.pack(self.flags)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return read_string(bytes, offset)[0] + + +class Delete(namedtuple('Delete', 'path version')): + type = 2 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.version)) + return b + + @classmethod + def deserialize(self, bytes, offset): + return True + + +class Exists(namedtuple('Exists', 'path watcher')): + type = 3 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend([1 if self.watcher else 0]) + return b + + @classmethod + def deserialize(cls, bytes, offset): + stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + return stat if stat.czxid != -1 else None + + +class GetData(namedtuple('GetData', 'path watcher')): + type = 4 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend([1 if self.watcher else 0]) + return b + + @classmethod + def deserialize(cls, bytes, offset): + data, offset = read_buffer(bytes, offset) + stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + return data, stat + + +class SetData(namedtuple('SetData', 'path data version')): + type = 5 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(write_buffer(self.data)) + b.extend(int_struct.pack(self.version)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + + +class GetACL(namedtuple('GetACL', 'path')): + type = 6 + + def serialize(self): + return bytearray(write_string(self.path)) + + @classmethod + def deserialize(cls, bytes, offset): + count = int_struct.unpack_from(bytes, offset)[0] + offset += int_struct.size + if count == -1: # pragma: nocover + return [] + + acls = [] + for c in range(count): + acl, offset = read_acl(bytes, offset) + acls.append(acl) + stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + return acls, stat + + +class SetACL(namedtuple('SetACL', 'path acls version')): + type = 7 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(len(self.acls))) + for acl in self.acls: + b.extend(int_struct.pack(acl.perms) + + write_string(acl.id.scheme) + write_string(acl.id.id)) + b.extend(int_struct.pack(self.version)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + return ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + + +class GetChildren(namedtuple('GetChildren', 'path watcher')): + type = 8 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend([1 if self.watcher else 0]) + return b + + @classmethod + def deserialize(cls, bytes, offset): + count = int_struct.unpack_from(bytes, offset)[0] + offset += int_struct.size + if count == -1: # pragma: nocover + return [] + + children = [] + for c in range(count): + child, offset = read_string(bytes, offset) + children.append(child) + return children + + +class Sync(namedtuple('Sync', 'path')): + type = 9 + + def serialize(self): + return write_string(self.path) + + @classmethod + def deserialize(cls, buffer, offset): + return read_string(buffer, offset)[0] + + +class 
GetChildren2(namedtuple('GetChildren2', 'path watcher')): + type = 12 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend([1 if self.watcher else 0]) + return b + + @classmethod + def deserialize(cls, bytes, offset): + count = int_struct.unpack_from(bytes, offset)[0] + offset += int_struct.size + if count == -1: # pragma: nocover + return [] + + children = [] + for c in range(count): + child, offset = read_string(bytes, offset) + children.append(child) + stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + return children, stat + + +class CheckVersion(namedtuple('CheckVersion', 'path version')): + type = 13 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.path)) + b.extend(int_struct.pack(self.version)) + return b + + +class Transaction(namedtuple('Transaction', 'operations')): + type = 14 + + def serialize(self): + b = bytearray() + for op in self.operations: + b.extend(MultiHeader(op.type, False, -1).serialize() + + op.serialize()) + return b + multiheader_struct.pack(-1, True, -1) + + @classmethod + def deserialize(cls, bytes, offset): + header = MultiHeader(None, False, None) + results = [] + response = None + while not header.done: + if header.type == Create.type: + response, offset = read_string(bytes, offset) + elif header.type == Delete.type: + response = True + elif header.type == SetData.type: + response = ZnodeStat._make( + stat_struct.unpack_from(bytes, offset)) + offset += stat_struct.size + elif header.type == CheckVersion.type: + response = True + elif header.type == -1: + err = int_struct.unpack_from(bytes, offset)[0] + offset += int_struct.size + response = EXCEPTIONS[err]() + if response: + results.append(response) + header, offset = MultiHeader.deserialize(bytes, offset) + return results + + @staticmethod + def unchroot(client, response): + resp = [] + for result in response: + if isinstance(result, basestring): + resp.append(client.unchroot(result)) + else: + resp.append(result) + return resp + + +class Reconfig(namedtuple('Reconfig', 'joining leaving new_members config_id')): + type = 16 + + def serialize(self): + b = bytearray() + b.extend(write_string(self.joining)) + b.extend(write_string(self.leaving)) + b.extend(write_string(self.new_members)) + b.extend(long_struct.pack(self.config_id)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + data, offset = read_buffer(bytes, offset) + stat = ZnodeStat._make(stat_struct.unpack_from(bytes, offset)) + return data, stat + + +class Auth(namedtuple('Auth', 'auth_type scheme auth')): + type = 100 + + def serialize(self): + return (int_struct.pack(self.auth_type) + write_string(self.scheme) + + write_string(self.auth)) + + +class Watch(namedtuple('Watch', 'type state path')): + @classmethod + def deserialize(cls, bytes, offset): + """Given bytes and the current bytes offset, return the + type, state, path, and new offset""" + type, state = int_int_struct.unpack_from(bytes, offset) + offset += int_int_struct.size + path, offset = read_string(bytes, offset) + return cls(type, state, path), offset + + +class ReplyHeader(namedtuple('ReplyHeader', 'xid, zxid, err')): + @classmethod + def deserialize(cls, bytes, offset): + """Given bytes and the current bytes offset, return a + :class:`ReplyHeader` instance and the new offset""" + new_offset = offset + reply_header_struct.size + return cls._make( + reply_header_struct.unpack_from(bytes, offset)), new_offset + + +class MultiHeader(namedtuple('MultiHeader', 'type done err')): + def 
serialize(self): + b = bytearray() + b.extend(int_struct.pack(self.type)) + b.extend([1 if self.done else 0]) + b.extend(int_struct.pack(self.err)) + return b + + @classmethod + def deserialize(cls, bytes, offset): + t, done, err = multiheader_struct.unpack_from(bytes, offset) + offset += multiheader_struct.size + return cls(t, done is 1, err), offset diff --git a/yarn/src/main/python/task-starter/kazoo/protocol/states.py b/yarn/src/main/python/task-starter/kazoo/protocol/states.py new file mode 100644 index 00000000..395c013f --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/protocol/states.py @@ -0,0 +1,237 @@ +"""Kazoo State and Event objects""" +from collections import namedtuple + + +class KazooState(object): + """High level connection state values + + States inspired by Netflix Curator. + + .. attribute:: SUSPENDED + + The connection has been lost but may be recovered. We should + operate in a "safe mode" until then. When the connection is + resumed, it may be discovered that the session expired. A + client should not assume that locks are valid during this + time. + + .. attribute:: CONNECTED + + The connection is alive and well. + + .. attribute:: LOST + + The connection has been confirmed dead. Any ephemeral nodes + will need to be recreated upon re-establishing a connection. + If locks were acquired or recipes using ephemeral nodes are in + use, they can be considered lost as well. + + """ + SUSPENDED = "SUSPENDED" + CONNECTED = "CONNECTED" + LOST = "LOST" + + +class KeeperState(object): + """Zookeeper State + + Represents the Zookeeper state. Watch functions will receive a + :class:`KeeperState` attribute as their state argument. + + .. attribute:: AUTH_FAILED + + Authentication has failed, this is an unrecoverable error. + + .. attribute:: CONNECTED + + Zookeeper is connected. + + .. attribute:: CONNECTED_RO + + Zookeeper is connected in read-only state. + + .. attribute:: CONNECTING + + Zookeeper is currently attempting to establish a connection. + + .. attribute:: EXPIRED_SESSION + + The prior session was invalid, all prior ephemeral nodes are + gone. + + """ + AUTH_FAILED = 'AUTH_FAILED' + CONNECTED = 'CONNECTED' + CONNECTED_RO = 'CONNECTED_RO' + CONNECTING = 'CONNECTING' + CLOSED = 'CLOSED' + EXPIRED_SESSION = 'EXPIRED_SESSION' + + +class EventType(object): + """Zookeeper Event + + Represents a Zookeeper event. Events trigger watch functions which + will receive a :class:`EventType` attribute as their event + argument. + + .. attribute:: CREATED + + A node has been created. + + .. attribute:: DELETED + + A node has been deleted. + + .. attribute:: CHANGED + + The data for a node has changed. + + .. attribute:: CHILD + + The children under a node have changed (a child was added or + removed). This event does not indicate the data for a child + node has changed, which must have its own watch established. + + """ + CREATED = 'CREATED' + DELETED = 'DELETED' + CHANGED = 'CHANGED' + CHILD = 'CHILD' + +EVENT_TYPE_MAP = { + 1: EventType.CREATED, + 2: EventType.DELETED, + 3: EventType.CHANGED, + 4: EventType.CHILD +} + + +class WatchedEvent(namedtuple('WatchedEvent', ('type', 'state', 'path'))): + """A change on ZooKeeper that a Watcher is able to respond to. + + The :class:`WatchedEvent` includes exactly what happened, the + current state of ZooKeeper, and the path of the node that was + involved in the event. An instance of :class:`WatchedEvent` will be + passed to registered watch functions. + + .. 
attribute:: type + + A :class:`EventType` attribute indicating the event type. + + .. attribute:: state + + A :class:`KeeperState` attribute indicating the Zookeeper + state. + + .. attribute:: path + + The path of the node for the watch event. + + """ + + +class Callback(namedtuple('Callback', ('type', 'func', 'args'))): + """A callback that is handed to a handler for dispatch + + :param type: Type of the callback, currently is only 'watch' + :param func: Callback function + :param args: Argument list for the callback function + + """ + + +class ZnodeStat(namedtuple('ZnodeStat', 'czxid mzxid ctime mtime version' + ' cversion aversion ephemeralOwner dataLength' + ' numChildren pzxid')): + """A ZnodeStat structure with convenience properties + + When getting the value of a node from Zookeeper, the properties for + the node known as a "Stat structure" will be retrieved. The + :class:`ZnodeStat` object provides access to the standard Stat + properties and additional properties that are more readable and use + Python time semantics (seconds since epoch instead of ms). + + .. note:: + + The original Zookeeper Stat name is in parens next to the name + when it differs from the convenience attribute. These are **not + functions**, just attributes. + + .. attribute:: creation_transaction_id (czxid) + + The transaction id of the change that caused this znode to be + created. + + .. attribute:: last_modified_transaction_id (mzxid) + + The transaction id of the change that last modified this znode. + + .. attribute:: created (ctime) + + The time in seconds from epoch when this node was created. + (ctime is in milliseconds) + + .. attribute:: last_modified (mtime) + + The time in seconds from epoch when this znode was last + modified. (mtime is in milliseconds) + + .. attribute:: version + + The number of changes to the data of this znode. + + .. attribute:: acl_version (aversion) + + The number of changes to the ACL of this znode. + + .. attribute:: owner_session_id (ephemeralOwner) + + The session id of the owner of this znode if the znode is an + ephemeral node. If it is not an ephemeral node, it will be + `None`. (ephemeralOwner will be 0 if it is not ephemeral) + + .. attribute:: data_length (dataLength) + + The length of the data field of this znode. + + .. attribute:: children_count (numChildren) + + The number of children of this znode. 
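+ + A short illustrative read (assuming ``zk`` is a connected + :class:`~kazoo.client.KazooClient`): + + .. code-block:: python + + data, stat = zk.get("/some/node") + print(stat.version, stat.children_count, stat.created)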
+ + """ + @property + def acl_version(self): + return self.aversion + + @property + def children_version(self): + return self.cversion + + @property + def created(self): + return self.ctime / 1000.0 + + @property + def last_modified(self): + return self.mtime / 1000.0 + + @property + def owner_session_id(self): + return self.ephemeralOwner or None + + @property + def creation_transaction_id(self): + return self.czxid + + @property + def last_modified_transaction_id(self): + return self.mzxid + + @property + def data_length(self): + return self.dataLength + + @property + def children_count(self): + return self.numChildren diff --git a/yarn/src/main/python/task-starter/kazoo/python2atexit.py b/yarn/src/main/python/task-starter/kazoo/python2atexit.py new file mode 100644 index 00000000..393e75ab --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/python2atexit.py @@ -0,0 +1,69 @@ +"""Uses the old atexit with added unregister for python 2.x +and the new atexit for python 3.x +""" + +import sys +import atexit + +__all__ = ["register", "unregister"] + + +_exithandlers = [] + + +def _run_exitfuncs(): + """run any registered exit functions + + _exithandlers is traversed in reverse order so functions are executed + last in, first out. + """ + + exc_info = None + while _exithandlers: + func, targs, kargs = _exithandlers.pop() + try: + func(*targs, **kargs) + except SystemExit: + exc_info = sys.exc_info() + except: + import traceback + sys.stderr.write("Error in atexit._run_exitfuncs:\n") + traceback.print_exc() + exc_info = sys.exc_info() + + if exc_info is not None: + raise exc_info[0](exc_info[1]) + + +def register(func, *targs, **kargs): + """register a function to be executed upon normal program termination + + func - function to be called at exit + targs - optional arguments to pass to func + kargs - optional keyword arguments to pass to func + + func is returned to facilitate usage as a decorator. 
+ """ + if hasattr(atexit, "unregister"): + atexit.register(func, *targs, **kargs) + else: + _exithandlers.append((func, targs, kargs)) + return func + + +def unregister(func): + """remove func from the list of functions that are registered + doesn't do anything if func is not found + + func = function to be unregistered + """ + if hasattr(atexit, "unregister"): + atexit.unregister(func) + else: + handler_entries = [e for e in _exithandlers if e[0] == func] + for e in handler_entries: + _exithandlers.remove(e) + +if not hasattr(atexit, "unregister"): + # Only in python 2.x + atexit.register(_run_exitfuncs) diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/__init__.py b/yarn/src/main/python/task-starter/kazoo/recipe/__init__.py new file mode 100644 index 00000000..792d6005 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/__init__.py @@ -0,0 +1 @@ +# diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/barrier.py b/yarn/src/main/python/task-starter/kazoo/recipe/barrier.py new file mode 100644 index 00000000..f4b3f311 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/barrier.py @@ -0,0 +1,215 @@ +"""Zookeeper Barriers + +:Maintainer: None +:Status: Unknown + +""" +import os +import socket +import uuid + +from kazoo.protocol.states import EventType +from kazoo.exceptions import KazooException +from kazoo.exceptions import NoNodeError +from kazoo.exceptions import NodeExistsError + + +class Barrier(object): + """Kazoo Barrier + + Implements a barrier to block processing of a set of nodes until + a condition is met at which point the nodes will be allowed to + proceed. The barrier is in place if its node exists. + + .. warning:: + + The :meth:`wait` function does not handle connection loss and + may raise :exc:`~kazoo.exceptions.ConnectionLossException` if + the connection is lost while waiting. + + """ + def __init__(self, client, path): + """Create a Kazoo Barrier + + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The barrier path to use. + + """ + self.client = client + self.path = path + + def create(self): + """Establish the barrier if it doesn't exist already""" + self.client.retry(self.client.ensure_path, self.path) + + def remove(self): + """Remove the barrier + + :returns: Whether the barrier actually needed to be removed. + :rtype: bool + + """ + try: + self.client.retry(self.client.delete, self.path) + return True + except NoNodeError: + return False + + def wait(self, timeout=None): + """Wait on the barrier to be cleared + + :returns: True if the barrier has been cleared, otherwise + False. + :rtype: bool + + """ + cleared = self.client.handler.event_object() + + def wait_for_clear(event): + if event.type == EventType.DELETED: + cleared.set() + + exists = self.client.exists(self.path, watch=wait_for_clear) + if not exists: + return True + + cleared.wait(timeout) + return cleared.is_set() + + +class DoubleBarrier(object): + """Kazoo Double Barrier + + Double barriers are used to synchronize the beginning and end of + a distributed task. The barrier blocks when entering it until all + the members have joined, and blocks when leaving until all the + members have left. + + .. note:: + + You should register a listener for session loss as the process + will no longer be part of the barrier once the session is + gone. Connection losses will be retried with the default retry + policy. 
+ + """ + def __init__(self, client, path, num_clients, identifier=None): + """Create a Double Barrier + + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The barrier path to use. + :param num_clients: How many clients must enter the barrier to + proceed. + :type num_clients: int + :param identifier: An identifier to use for this member of the + barrier when participating. Defaults to the + hostname + process id. + + """ + self.client = client + self.path = path + self.num_clients = num_clients + self._identifier = identifier or '%s-%s' % ( + socket.getfqdn(), os.getpid()) + self.participating = False + self.assured_path = False + self.node_name = uuid.uuid4().hex + self.create_path = self.path + "/" + self.node_name + + def enter(self): + """Enter the barrier, blocks until all nodes have entered""" + try: + self.client.retry(self._inner_enter) + self.participating = True + except KazooException: + # We failed to enter, best effort cleanup + self._best_effort_cleanup() + self.participating = False + + def _inner_enter(self): + # make sure our barrier parent node exists + if not self.assured_path: + self.client.ensure_path(self.path) + self.assured_path = True + + ready = self.client.handler.event_object() + + try: + self.client.create( + self.create_path, + self._identifier.encode('utf-8'), ephemeral=True) + except NodeExistsError: + pass + + def created(event): + if event.type == EventType.CREATED: + ready.set() + + self.client.exists(self.path + '/' + 'ready', watch=created) + + children = self.client.get_children(self.path) + + if len(children) < self.num_clients: + ready.wait() + else: + self.client.ensure_path(self.path + '/ready') + return True + + def leave(self): + """Leave the barrier, blocks until all nodes have left""" + try: + self.client.retry(self._inner_leave) + except KazooException: # pragma: nocover + # Failed to cleanly leave + self._best_effort_cleanup() + self.participating = False + + def _inner_leave(self): + # Delete the ready node if its around + try: + self.client.delete(self.path + '/ready') + except NoNodeError: + pass + + while True: + children = self.client.get_children(self.path) + if not children: + return True + + if len(children) == 1 and children[0] == self.node_name: + self.client.delete(self.create_path) + return True + + children.sort() + + ready = self.client.handler.event_object() + + def deleted(event): + if event.type == EventType.DELETED: + ready.set() + + if self.node_name == children[0]: + # We're first, wait on the highest to leave + if not self.client.exists(self.path + '/' + children[-1], + watch=deleted): + continue + + ready.wait() + continue + + # Delete our node + self.client.delete(self.create_path) + + # Wait on the first + if not self.client.exists(self.path + '/' + children[0], + watch=deleted): + continue + + # Wait for the lowest to be deleted + ready.wait() + + def _best_effort_cleanup(self): + try: + self.client.retry(self.client.delete, self.create_path) + except NoNodeError: + pass diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/counter.py b/yarn/src/main/python/task-starter/kazoo/recipe/counter.py new file mode 100644 index 00000000..ed80f51b --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/counter.py @@ -0,0 +1,94 @@ +"""Zookeeper Counter + +:Maintainer: None +:Status: Unknown + +""" + +from kazoo.exceptions import BadVersionError +from kazoo.retry import ForceRetryError + + +class Counter(object): + """Kazoo Counter + + A shared counter of either int or float values. 
Changes to the + counter are done atomically. The general retry policy is used to + retry operations if concurrent changes are detected. + + The data is marshaled using `repr(value)` and converted back using + `type(counter.default)(value)` both using an ascii encoding. As + such other data types might be used for the counter value. + + Counter changes can raise + :class:`~kazoo.exceptions.BadVersionError` if the retry policy + wasn't able to apply a change. + + Example usage: + + .. code-block:: python + + zk = KazooClient() + counter = zk.Counter("/int") + counter += 2 + counter -= 1 + counter.value == 1 + + counter = zk.Counter("/float", default=1.0) + counter += 2.0 + counter.value == 3.0 + + """ + def __init__(self, client, path, default=0): + """Create a Kazoo Counter + + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The counter path to use. + :param default: The default value. + + """ + self.client = client + self.path = path + self.default = default + self.default_type = type(default) + self._ensured_path = False + + def _ensure_node(self): + if not self._ensured_path: + # make sure our node exists + self.client.ensure_path(self.path) + self._ensured_path = True + + def _value(self): + self._ensure_node() + old, stat = self.client.get(self.path) + old = old.decode('ascii') if old != b'' else self.default + version = stat.version + data = self.default_type(old) + return data, version + + @property + def value(self): + return self._value()[0] + + def _change(self, value): + if not isinstance(value, self.default_type): + raise TypeError('invalid type for value change') + self.client.retry(self._inner_change, value) + return self + + def _inner_change(self, value): + data, version = self._value() + data = repr(data + value).encode('ascii') + try: + self.client.set(self.path, data, version=version) + except BadVersionError: # pragma: nocover + raise ForceRetryError() + + def __add__(self, value): + """Add value to counter.""" + return self._change(value) + + def __sub__(self, value): + """Subtract value from counter.""" + return self._change(-value) diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/election.py b/yarn/src/main/python/task-starter/kazoo/recipe/election.py new file mode 100644 index 00000000..3089fa69 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/election.py @@ -0,0 +1,79 @@ +"""ZooKeeper Leader Elections + +:Maintainer: None +:Status: Unknown + +""" +from kazoo.exceptions import CancelledError + + +class Election(object): + """Kazoo Basic Leader Election + + Example usage with a :class:`~kazoo.client.KazooClient` instance:: + + zk = KazooClient() + election = zk.Election("/electionpath", "my-identifier") + + # blocks until the election is won, then calls + # my_leader_function() + election.run(my_leader_function) + + """ + def __init__(self, client, path, identifier=None): + """Create a Kazoo Leader Election + + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The election path to use. + :param identifier: Name to use for this lock contender. This + can be useful for querying to see who the + current lock contenders are. + + """ + self.lock = client.Lock(path, identifier) + + def run(self, func, *args, **kwargs): + """Contend for the leadership + + This call will block until either this contender is cancelled + or this contender wins the election and the provided leadership + function subsequently returns or fails. 
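+ + For example, ``election.run(my_leader_function)`` blocks until this + contender wins and ``my_leader_function`` returns (illustrative; see + the class docstring above).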
+ + :param func: A function to be called if/when the election is + won. + :param args: Arguments to leadership function. + :param kwargs: Keyword arguments to leadership function. + + """ + if not callable(func): + raise ValueError("leader function is not callable") + + try: + with self.lock: + func(*args, **kwargs) + + except CancelledError: + pass + + def cancel(self): + """Cancel participation in the election + + .. note:: + + If this contender has already been elected leader, this + method will not interrupt the leadership function. + + """ + self.lock.cancel() + + def contenders(self): + """Return an ordered list of the current contenders in the + election + + .. note:: + + If the contenders did not set an identifier, it will appear + as a blank string. + + """ + return self.lock.contenders() diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/lease.py b/yarn/src/main/python/task-starter/kazoo/recipe/lease.py new file mode 100644 index 00000000..de68405f --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/lease.py @@ -0,0 +1,130 @@ +"""Zookeeper lease implementations + +:Maintainer: Lars Albertsson +:Maintainer: Jyrki Pulliainen +:Status: Beta + +""" + +import json +import socket +import datetime +from kazoo.exceptions import CancelledError + + +class NonBlockingLease(object): + """Exclusive lease that does not block. + + An exclusive lease ensures that only one client at a time owns the lease. The client may + renew the lease without losing it by obtaining a new lease with the same path and same + identity. The lease object evaluates to True if the lease was obtained. + + A common use case is a situation where a task should only run on a single host. In this + case, the clients that did not obtain the lease should exit without performing the protected + task. + + The lease stores time stamps using client clocks, and will therefore only work if client clocks + are roughly synchronised. It uses UTC, and works across time zones and daylight savings. + + Example usage with a :class:`~kazoo.client.KazooClient` instance:: + + zk = KazooClient() + # Hold lease over an hour in order to keep job on same machine, with failover if it dies. + lease = zk.NonBlockingLease("/db_leases/hourly_cleanup", datetime.timedelta(minutes = 70), + identifier = "DB hourly cleanup on " + socket.gethostname()) + if lease: + do_hourly_database_cleanup() + """ + + # Bump when storage format changes + _version = 1 + _date_format = "%Y-%m-%dT%H:%M:%S" + _byte_encoding = 'utf-8' + + def __init__(self, client, path, duration, identifier=None, utcnow=datetime.datetime.utcnow): + """Create a non-blocking lease. + + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The lease path to use. + :param duration: Duration during which the lease is reserved. A :class:`~datetime.timedelta` instance. + :param identifier: Unique name to use for this lease holder. Reuse in order to renew the lease. + Defaults to :meth:`socket.gethostname()`. + :param utcnow: Clock function, by default returning :meth:`datetime.datetime.utcnow()`. Used for testing.
+ """ + ident = identifier or socket.gethostname() + self.obtained = False + self._attempt_obtaining(client, path, duration, ident, utcnow) + + def _attempt_obtaining(self, client, path, duration, ident, utcnow): + client.ensure_path(path) + holder_path = path + "/lease_holder" + lock = client.Lock(path, ident) + try: + with lock: + now = utcnow() + if client.exists(holder_path): + raw, _ = client.get(holder_path) + data = self._decode(raw) + if data["version"] != self._version: + # We need an upgrade, let someone else take the lease + return + current_end = datetime.datetime.strptime(data['end'], self._date_format) + if data['holder'] != ident and now < current_end: + # Another client is still holding the lease + return + client.delete(holder_path) + end_lease = (now + duration).strftime(self._date_format) + new_data = {'version': self._version, 'holder': ident, 'end': end_lease} + client.create(holder_path, self._encode(new_data)) + self.obtained = True + + except CancelledError: + pass + + def _encode(self, data_dict): + return json.dumps(data_dict).encode(self._byte_encoding) + + def _decode(self, raw): + return json.loads(raw.decode(self._byte_encoding)) + + # Python 2.x + def __nonzero__(self): + return self.obtained + + # Python 3.x + def __bool__(self): + return self.obtained + + +class MultiNonBlockingLease(object): + """Exclusive lease for multiple clients. + + This type of lease is useful when a limited set of hosts should run a particular task. + It will attempt to obtain leases trying a sequence of ZooKeeper lease paths. + + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param count: Number of host leases allowed. + :param path: ZooKeeper path under which lease files are stored. + :param duration: Duration during which the lease is reserved. A :class:`~datetime.timedelta` instance. + :param identifier: Unique name to use for this lease holder. Reuse in order to renew the lease. + Defaults do :meth:`socket.gethostname()`. + :param utcnow: Clock function, by default returning :meth:`datetime.datetime.utcnow()`. Used for testing. + """ + + def __init__(self, client, count, path, duration, identifier=None, + utcnow=datetime.datetime.utcnow): + self.obtained = False + for num in range(count): + ls = NonBlockingLease(client, '%s/%d' % (path, num), duration, + identifier=identifier, utcnow=utcnow) + if ls: + self.obtained = True + break + + # Python 2.x + def __nonzero__(self): + return self.obtained + + # Python 3.x + def __bool__(self): + return self.obtained diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/lock.py b/yarn/src/main/python/task-starter/kazoo/recipe/lock.py new file mode 100644 index 00000000..2b4fae89 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/lock.py @@ -0,0 +1,584 @@ +"""Zookeeper Locking Implementations + +:Maintainer: Ben Bangert +:Status: Production + +Error Handling +============== + +It's highly recommended to add a state listener with +:meth:`~KazooClient.add_listener` and watch for +:attr:`~KazooState.LOST` and :attr:`~KazooState.SUSPENDED` state +changes and re-act appropriately. In the event that a +:attr:`~KazooState.LOST` state occurs, its certain that the lock +and/or the lease has been lost. 
+ +""" + +import sys +try: + from time import monotonic as now +except ImportError: + from time import time as now +import uuid + +import six + +from kazoo.retry import ( + KazooRetry, + RetryFailedError, + ForceRetryError +) +from kazoo.exceptions import CancelledError +from kazoo.exceptions import KazooException +from kazoo.exceptions import LockTimeout +from kazoo.exceptions import NoNodeError +from kazoo.protocol.states import KazooState + + +class _Watch(object): + def __init__(self, duration=None): + self.duration = duration + self.started_at = None + + def start(self): + self.started_at = now() + + def leftover(self): + if self.duration is None: + return None + else: + elapsed = now() - self.started_at + return max(0, self.duration - elapsed) + + +class Lock(object): + """Kazoo Lock + + Example usage with a :class:`~kazoo.client.KazooClient` instance: + + .. code-block:: python + + zk = KazooClient() + lock = zk.Lock("/lockpath", "my-identifier") + with lock: # blocks waiting for lock acquisition + # do something with the lock + + Note: This lock is not *re-entrant*. Repeated calls after already + acquired will block. + + """ + _NODE_NAME = '__lock__' + + def __init__(self, client, path, identifier=None): + """Create a Kazoo lock. + + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The lock path to use. + :param identifier: Name to use for this lock contender. This + can be useful for querying to see who the + current lock contenders are. + + """ + self.client = client + self.path = path + + # some data is written to the node. this can be queried via + # contenders() to see who is contending for the lock + self.data = str(identifier or "").encode('utf-8') + + self.wake_event = client.handler.event_object() + + # props to Netflix Curator for this trick. It is possible for our + # create request to succeed on the server, but for a failure to + # prevent us from getting back the full path name. We prefix our + # lock name with a uuid and can check for its presence on retry. + self.prefix = uuid.uuid4().hex + self._NODE_NAME + self.create_path = self.path + "/" + self.prefix + + self.create_tried = False + self.is_acquired = False + self.assured_path = False + self.cancelled = False + self._retry = KazooRetry(max_tries=None, + sleep_func=client.handler.sleep_func) + self._lock = client.handler.lock_object() + + def _ensure_path(self): + self.client.ensure_path(self.path) + self.assured_path = True + + def cancel(self): + """Cancel a pending lock acquire.""" + self.cancelled = True + self.wake_event.set() + + def acquire(self, blocking=True, timeout=None): + """ + Acquire the lock. By defaults blocks and waits forever. + + :param blocking: Block until lock is obtained or return immediately. + :type blocking: bool + :param timeout: Don't wait forever to acquire the lock. + :type timeout: float or None + + :returns: Was the lock acquired? + :rtype: bool + + :raises: :exc:`~kazoo.exceptions.LockTimeout` if the lock + wasn't acquired within `timeout` seconds. + + .. versionadded:: 1.1 + The timeout option. + """ + + def _acquire_lock(): + got_it = self._lock.acquire(False) + if not got_it: + raise ForceRetryError() + return True + + retry = self._retry.copy() + retry.deadline = timeout + + # Ensure we are locked so that we avoid multiple threads in + # this acquistion routine at the same time... + locked = self._lock.acquire(False) + if not locked and not blocking: + return False + if not locked: + # Lock acquire doesn't take a timeout, so simulate it... 
+ try: + locked = retry(_acquire_lock) + except RetryFailedError: + return False + already_acquired = self.is_acquired + try: + gotten = False + try: + gotten = retry(self._inner_acquire, + blocking=blocking, timeout=timeout) + except RetryFailedError: + if not already_acquired: + self._best_effort_cleanup() + except KazooException: + # if we did ultimately fail, attempt to clean up + exc_info = sys.exc_info() + if not already_acquired: + self._best_effort_cleanup() + self.cancelled = False + six.reraise(exc_info[0], exc_info[1], exc_info[2]) + if gotten: + self.is_acquired = gotten + if not gotten and not already_acquired: + self._delete_node(self.node) + return gotten + finally: + self._lock.release() + + def _watch_session(self, state): + self.wake_event.set() + return True + + def _inner_acquire(self, blocking, timeout): + + # wait until it's our chance to get it.. + if self.is_acquired: + if not blocking: + return False + raise ForceRetryError() + + # make sure our election parent node exists + if not self.assured_path: + self._ensure_path() + + node = None + if self.create_tried: + node = self._find_node() + else: + self.create_tried = True + + if not node: + node = self.client.create(self.create_path, self.data, + ephemeral=True, sequence=True) + # strip off path to node + node = node[len(self.path) + 1:] + + self.node = node + + while True: + self.wake_event.clear() + + # bail out with an exception if cancellation has been requested + if self.cancelled: + raise CancelledError() + + children = self._get_sorted_children() + + try: + our_index = children.index(node) + except ValueError: # pragma: nocover + # somehow we aren't in the children -- probably we are + # recovering from a session failure and our ephemeral + # node was removed + raise ForceRetryError() + + if self.acquired_lock(children, our_index): + return True + + if not blocking: + return False + + # otherwise we are in the mix. 
watch predecessor and bide our time + predecessor = self.path + "/" + children[our_index - 1] + self.client.add_listener(self._watch_session) + try: + if self.client.exists(predecessor, self._watch_predecessor): + self.wake_event.wait(timeout) + if not self.wake_event.isSet(): + raise LockTimeout("Failed to acquire lock on %s after " + "%s seconds" % (self.path, timeout)) + finally: + self.client.remove_listener(self._watch_session) + + def acquired_lock(self, children, index): + return index == 0 + + def _watch_predecessor(self, event): + self.wake_event.set() + + def _get_sorted_children(self): + children = self.client.get_children(self.path) + + # can't just sort directly: the node names are prefixed by uuids + lockname = self._NODE_NAME + children.sort(key=lambda c: c[c.find(lockname) + len(lockname):]) + return children + + def _find_node(self): + children = self.client.get_children(self.path) + for child in children: + if child.startswith(self.prefix): + return child + return None + + def _delete_node(self, node): + self.client.delete(self.path + "/" + node) + + def _best_effort_cleanup(self): + try: + node = self._find_node() + if node: + self._delete_node(node) + except KazooException: # pragma: nocover + pass + + def release(self): + """Release the lock immediately.""" + return self.client.retry(self._inner_release) + + def _inner_release(self): + if not self.is_acquired: + return False + + try: + self._delete_node(self.node) + except NoNodeError: # pragma: nocover + pass + + self.is_acquired = False + self.node = None + return True + + def contenders(self): + """Return an ordered list of the current contenders for the + lock. + + .. note:: + + If the contenders did not set an identifier, it will appear + as a blank string. + + """ + # make sure our election parent node exists + if not self.assured_path: + self._ensure_path() + + children = self._get_sorted_children() + + contenders = [] + for child in children: + try: + data, stat = self.client.get(self.path + "/" + child) + contenders.append(data.decode('utf-8')) + except NoNodeError: # pragma: nocover + pass + return contenders + + def __enter__(self): + self.acquire() + + def __exit__(self, exc_type, exc_value, traceback): + self.release() + + +class Semaphore(object): + """A Zookeeper-based Semaphore + + This synchronization primitive operates in the same manner as the + Python threading version, except that it uses the concept of leases + to indicate how many leases are available for the lock rather than + counting. + + Note: This lock is not meant to be *re-entrant*. + + Example: + + .. code-block:: python + + zk = KazooClient() + semaphore = zk.Semaphore("/leasepath", "my-identifier") + with semaphore: # blocks waiting for lock acquisition + # do something with the semaphore + + .. warning:: + + This class stores the allowed max_leases as the data on the + top-level semaphore node. The stored value is checked once + against the max_leases of each instance. This check is + performed when acquire is called the first time. The semaphore + node needs to be deleted to change the allowed leases. + + .. versionadded:: 0.6 + The Semaphore class. + + .. versionadded:: 1.1 + The max_leases check. + + """ + def __init__(self, client, path, identifier=None, max_leases=1): + """Create a Kazoo Semaphore + + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The semaphore path to use. + :param identifier: Name to use for this lock contender.
This + can be useful for querying to see who the + current lock contenders are. + :param max_leases: The maximum amount of leases available for + the semaphore. + + """ + # Implementation notes about how excessive thundering herd + # and watches are avoided + # - A node (lease pool) holds children for each lease in use + # - A lock is acquired for a process attempting to acquire a + # lease. If a lease is available, the ephemeral node is + # created in the lease pool and the lock is released. + # - Only the lock holder watches for children changes in the + # lease pool + self.client = client + self.path = path + + # some data is written to the node. this can be queried via + # contenders() to see who is contending for the lock + self.data = str(identifier or "").encode('utf-8') + self.max_leases = max_leases + self.wake_event = client.handler.event_object() + + self.create_path = self.path + "/" + uuid.uuid4().hex + self.lock_path = path + '-' + '__lock__' + self.is_acquired = False + self.assured_path = False + self.cancelled = False + self._session_expired = False + + def _ensure_path(self): + result = self.client.ensure_path(self.path) + self.assured_path = True + if result is True: + # node did already exist + data, _ = self.client.get(self.path) + try: + leases = int(data.decode('utf-8')) + except (ValueError, TypeError): + # ignore non-numeric data, maybe the node data is used + # for other purposes + pass + else: + if leases != self.max_leases: + raise ValueError( + "Inconsistent max leases: %s, expected: %s" % + (leases, self.max_leases) + ) + else: + self.client.set(self.path, str(self.max_leases).encode('utf-8')) + + def cancel(self): + """Cancel a pending semaphore acquire.""" + self.cancelled = True + self.wake_event.set() + + def acquire(self, blocking=True, timeout=None): + """Acquire the semaphore. By default, blocks and waits forever. + + :param blocking: Block until semaphore is obtained or + return immediately. + :type blocking: bool + :param timeout: Don't wait forever to acquire the semaphore. + :type timeout: float or None + + :returns: Was the semaphore acquired? + :rtype: bool + + :raises: + ValueError if the max_leases value doesn't match the + stored value. + + :exc:`~kazoo.exceptions.LockTimeout` if the semaphore + wasn't acquired within `timeout` seconds. + + .. versionadded:: 1.1 + The blocking, timeout arguments and the max_leases check. + """ + # If the semaphore had previously been canceled, make sure to + # reset that state. + self.cancelled = False + + try: + self.is_acquired = self.client.retry( + self._inner_acquire, blocking=blocking, timeout=timeout) + except KazooException: + # if we did ultimately fail, attempt to clean up + self._best_effort_cleanup() + self.cancelled = False + raise + + return self.is_acquired + + def _inner_acquire(self, blocking, timeout=None): + """Inner loop that runs from the top anytime a command hits a + retryable Zookeeper exception.""" + self._session_expired = False + self.client.add_listener(self._watch_session) + + if not self.assured_path: + self._ensure_path() + + # Do we already have a lease? + if self.client.exists(self.create_path): + return True + + w = _Watch(duration=timeout) + w.start() + lock = self.client.Lock(self.lock_path, self.data) + gotten = lock.acquire(blocking=blocking, timeout=w.leftover()) + if not gotten: + return False + try: + while True: + self.wake_event.clear() + + # Attempt to grab our lease...
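+ # (Illustrative note: _get_lease() re-registers the child + # watch on each pass, so wake_event is set whenever the + # set of current lease holders changes.)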
+ if self._get_lease(): + return True + + if blocking: + # If blocking, wait until self._watch_lease_change() is + # called before returning + self.wake_event.wait(w.leftover()) + if not self.wake_event.isSet(): + raise LockTimeout( + "Failed to acquire semaphore on %s " + "after %s seconds" % (self.path, timeout)) + else: + return False + finally: + lock.release() + + def _watch_lease_change(self, event): + self.wake_event.set() + + def _get_lease(self, data=None): + # Make sure the session is still valid + if self._session_expired: + raise ForceRetryError("Retry on session loss at top") + + # Make sure that the request hasn't been canceled + if self.cancelled: + raise CancelledError("Semaphore cancelled") + + # Get a list of the current potential lock holders. If they change, + # notify our wake_event object. This is used to unblock a blocking + # self._inner_acquire call. + children = self.client.get_children(self.path, + self._watch_lease_change) + + # If there are leases available, acquire one + if len(children) < self.max_leases: + self.client.create(self.create_path, self.data, ephemeral=True) + + # Check if our acquisition was successful or not. Update our state. + if self.client.exists(self.create_path): + self.is_acquired = True + else: + self.is_acquired = False + + # Return current state + return self.is_acquired + + def _watch_session(self, state): + if state == KazooState.LOST: + self._session_expired = True + self.wake_event.set() + + # Return true to de-register + return True + + def _best_effort_cleanup(self): + try: + self.client.delete(self.create_path) + except KazooException: # pragma: nocover + pass + + def release(self): + """Release the lease immediately.""" + return self.client.retry(self._inner_release) + + def _inner_release(self): + if not self.is_acquired: + return False + try: + self.client.delete(self.create_path) + except NoNodeError: # pragma: nocover + pass + self.is_acquired = False + return True + + def lease_holders(self): + """Return an unordered list of the current lease holders. + + .. note:: + + If the lease holder did not set an identifier, it will + appear as a blank string. + + """ + if not self.client.exists(self.path): + return [] + + children = self.client.get_children(self.path) + + lease_holders = [] + for child in children: + try: + data, stat = self.client.get(self.path + "/" + child) + lease_holders.append(data.decode('utf-8')) + except NoNodeError: # pragma: nocover + pass + return lease_holders + + def __enter__(self): + self.acquire() + + def __exit__(self, exc_type, exc_value, traceback): + self.release() diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/partitioner.py b/yarn/src/main/python/task-starter/kazoo/recipe/partitioner.py new file mode 100644 index 00000000..79db8a43 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/partitioner.py @@ -0,0 +1,423 @@ +"""Zookeeper Partitioner Implementation + +:Maintainer: None +:Status: Unknown + +:class:`SetPartitioner` implements a partitioning scheme using +Zookeeper for dividing up resources amongst members of a party. + +This is useful when there is a set of resources that should only be +accessed by a single process at a time that multiple processes +across a cluster might want to divide up. + +Example Use-Case +---------------- + +- Multiple workers across a cluster need to divide up a list of queues + so that no two workers own the same queue. 
+ +""" +import logging +import os +import socket +from functools import partial + +from kazoo.exceptions import KazooException, LockTimeout +from kazoo.protocol.states import KazooState +from kazoo.recipe.watchers import PatientChildrenWatch + +log = logging.getLogger(__name__) + + +class PartitionState(object): + """High level partition state values + + .. attribute:: ALLOCATING + + The set needs to be partitioned, and may require an existing + partition set to be released before acquiring a new partition + of the set. + + .. attribute:: ACQUIRED + + The set has been partitioned and acquired. + + .. attribute:: RELEASE + + The set needs to be repartitioned, and the current partitions + must be released before a new allocation can be made. + + .. attribute:: FAILURE + + The set partition has failed. This occurs when the maximum + time to partition the set is exceeded or the Zookeeper session + is lost. The partitioner is unusable after this state and must + be recreated. + + """ + ALLOCATING = "ALLOCATING" + ACQUIRED = "ACQUIRED" + RELEASE = "RELEASE" + FAILURE = "FAILURE" + + +class SetPartitioner(object): + """Partitions a set amongst members of a party + + This class will partition a set amongst members of a party such + that each member will be given zero or more items of the set and + each set item will be given to a single member. When new members + enter or leave the party, the set will be re-partitioned amongst + the members. + + When the :class:`SetPartitioner` enters the + :attr:`~PartitionState.FAILURE` state, it is unrecoverable + and a new :class:`SetPartitioner` should be created. + + Example: + + .. code-block:: python + + from kazoo.client import KazooClient + client = KazooClient() + + qp = client.SetPartitioner( + path='/work_queues', set=('queue-1', 'queue-2', 'queue-3')) + + while 1: + if qp.failed: + raise Exception("Lost or unable to acquire partition") + elif qp.release: + qp.release_set() + elif qp.acquired: + for partition in qp: + # Do something with each partition + elif qp.allocating: + qp.wait_for_acquire() + + **State Transitions** + + When created, the :class:`SetPartitioner` enters the + :attr:`PartitionState.ALLOCATING` state. + + :attr:`~PartitionState.ALLOCATING` -> + :attr:`~PartitionState.ACQUIRED` + + Set was partitioned successfully, the partition list assigned + is accessible via list/iter methods or calling list() on the + :class:`SetPartitioner` instance. + + :attr:`~PartitionState.ALLOCATING` -> + :attr:`~PartitionState.FAILURE` + + Allocating the set failed either due to a Zookeeper session + expiration, or failure to acquire the items of the set within + the timeout period. + + :attr:`~PartitionState.ACQUIRED` -> + :attr:`~PartitionState.RELEASE` + + The members of the party have changed, and the set needs to be + repartitioned. :meth:`SetPartitioner.release` should be called + as soon as possible. + + :attr:`~PartitionState.ACQUIRED` -> + :attr:`~PartitionState.FAILURE` + + The current partition was lost due to a Zookeeper session + expiration. + + :attr:`~PartitionState.RELEASE` -> + :attr:`~PartitionState.ALLOCATING` + + The current partition was released and is being re-allocated. + + """ + def __init__(self, client, path, set, partition_func=None, + identifier=None, time_boundary=30, max_reaction_time=1, + state_change_event=None): + """Create a :class:`~SetPartitioner` instance + + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The partition path to use. + :param set: The set of items to partition. 
+ :param partition_func: A function to use to decide how to + partition the set. + :param identifier: An identifier to use for this member of the + party when participating. Defaults to the + hostname + process id. + :param time_boundary: How long the party members must be stable + before allocation can complete. + :param max_reaction_time: Maximum reaction time for party members + change. + :param state_change_event: An optional Event object that will be set + on every state change. + + """ + # Used to differentiate two states with the same names in time + self.state_id = 0 + self.state = PartitionState.ALLOCATING + self.state_change_event = state_change_event or \ + client.handler.event_object() + + self._client = client + self._path = path + self._set = set + self._partition_set = [] + self._partition_func = partition_func or self._partitioner + self._identifier = identifier or '%s-%s' % ( + socket.getfqdn(), os.getpid()) + self._locks = [] + self._lock_path = '/'.join([path, 'locks']) + self._party_path = '/'.join([path, 'party']) + self._time_boundary = time_boundary + self._max_reaction_time = max_reaction_time + + self._acquire_event = client.handler.event_object() + + # Create basic path nodes + client.ensure_path(path) + client.ensure_path(self._lock_path) + client.ensure_path(self._party_path) + + # Join the party + self._party = client.ShallowParty(self._party_path, + identifier=self._identifier) + self._party.join() + + self._state_change = client.handler.rlock_object() + client.add_listener(self._establish_sessionwatch) + + # Now watch the party and set the callback on the async result + # so we know when we're ready + self._child_watching(self._allocate_transition, async=True) + + def __iter__(self): + """Return the partitions in this partition set""" + for partition in self._partition_set: + yield partition + + @property + def failed(self): + """Corresponds to the :attr:`PartitionState.FAILURE` state""" + return self.state == PartitionState.FAILURE + + @property + def release(self): + """Corresponds to the :attr:`PartitionState.RELEASE` state""" + return self.state == PartitionState.RELEASE + + @property + def allocating(self): + """Corresponds to the :attr:`PartitionState.ALLOCATING` + state""" + return self.state == PartitionState.ALLOCATING + + @property + def acquired(self): + """Corresponds to the :attr:`PartitionState.ACQUIRED` state""" + return self.state == PartitionState.ACQUIRED + + def wait_for_acquire(self, timeout=30): + """Wait for the set to be partitioned and acquired + + :param timeout: How long to wait before returning. + :type timeout: int + + """ + self._acquire_event.wait(timeout) + + def release_set(self): + """Call to release the set + + This method begins the step of allocating once the set has + been released. 
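+        (The partitioner then re-enters the
+        :attr:`~PartitionState.ALLOCATING` state and watches the party
+        again until it settles.)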
+ + """ + self._release_locks() + if self._locks: # pragma: nocover + # This shouldn't happen, it means we couldn't release our + # locks, abort + self._fail_out() + return + else: + with self._state_change: + if self.failed: + return + self._set_state(PartitionState.ALLOCATING) + self._child_watching(self._allocate_transition, async=True) + + def finish(self): + """Call to release the set and leave the party""" + self._release_locks() + self._fail_out() + + def _fail_out(self): + with self._state_change: + self._set_state(PartitionState.FAILURE) + if self._party.participating: + try: + self._party.leave() + except KazooException: # pragma: nocover + pass + + def _allocate_transition(self, result): + """Called when in allocating mode, and the children settled""" + + # Did we get an exception waiting for children to settle? + if result.exception: # pragma: nocover + self._fail_out() + return + + children, async_result = result.get() + children_changed = self._client.handler.event_object() + + def updated(result): + with self._state_change: + children_changed.set() + if self.acquired: + self._set_state(PartitionState.RELEASE) + + with self._state_change: + # We can lose connection during processing the event + if not self.allocating: + return + + # Remember the state ID to check later for race conditions + state_id = self.state_id + + # updated() will be called when children change + async_result.rawlink(updated) + + # Check whether the state has changed during the lock acquisition + # and abort the process if so. + def abort_if_needed(): + if self.state_id == state_id: + if children_changed.is_set(): + # The party has changed. Repartitioning... + self._abort_lock_acquisition() + return True + else: + return False + else: + if self.allocating or self.acquired: + # The connection was lost and user initiated a new + # allocation process. Abort it to eliminate race + # conditions with locks. + with self._state_change: + self._set_state(PartitionState.RELEASE) + + return True + + # Split up the set + partition_set = self._partition_func( + self._identifier, list(self._party), self._set) + + # Proceed to acquire locks for the working set as needed + for member in partition_set: + lock = self._client.Lock(self._lock_path + '/' + str(member)) + + while True: + try: + # We mustn't lock without timeout because in that case we + # can get a deadlock if the party state will change during + # lock acquisition. + lock.acquire(timeout=self._max_reaction_time) + except LockTimeout: + if abort_if_needed(): + return + except KazooException: + return self.finish() + else: + break + + self._locks.append(lock) + + if abort_if_needed(): + return + + # All locks acquired. Time for state transition. + with self._state_change: + if self.state_id == state_id and not children_changed.is_set(): + self._partition_set = partition_set + self._set_state(PartitionState.ACQUIRED) + self._acquire_event.set() + return + + if not abort_if_needed(): + # This mustn't happen. Means a logical error. 
+ self._fail_out() + + def _release_locks(self): + """Attempt to completely remove all the locks""" + self._acquire_event.clear() + for lock in self._locks[:]: + try: + lock.release() + except KazooException: # pragma: nocover + # We proceed to remove as many as possible, and leave + # the ones we couldn't remove + pass + else: + self._locks.remove(lock) + + def _abort_lock_acquisition(self): + """Called during lock acquisition if a party change occurs""" + + self._release_locks() + + if self._locks: + # This shouldn't happen, it means we couldn't release our + # locks, abort + self._fail_out() + return + + self._child_watching(self._allocate_transition, async=True) + + def _child_watching(self, func=None, async=False): + """Called when children are being watched to stabilize + + This actually returns immediately, child watcher spins up a + new thread/greenlet and waits for it to stabilize before + any callbacks might run. + + """ + watcher = PatientChildrenWatch(self._client, self._party_path, + self._time_boundary) + asy = watcher.start() + if func is not None: + # We spin up the function in a separate thread/greenlet + # to ensure that the rawlink's it might use won't be + # blocked + if async: + func = partial(self._client.handler.spawn, func) + asy.rawlink(func) + return asy + + def _establish_sessionwatch(self, state): + """Register ourself to listen for session events, we shut down + if we become lost""" + with self._state_change: + if self.failed: + pass + elif state == KazooState.LOST: + self._client.handler.spawn(self._fail_out) + elif not self.release: + self._set_state(PartitionState.RELEASE) + + return state == KazooState.LOST + + def _partitioner(self, identifier, members, partitions): + # Ensure consistent order of partitions/members + all_partitions = sorted(partitions) + workers = sorted(members) + + i = workers.index(identifier) + # Now return the partition list starting at our location and + # skipping the other workers + return all_partitions[i::len(workers)] + + def _set_state(self, state): + self.state = state + self.state_id += 1 + self.state_change_event.set() diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/party.py b/yarn/src/main/python/task-starter/kazoo/recipe/party.py new file mode 100644 index 00000000..7186c10a --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/party.py @@ -0,0 +1,118 @@ +"""Party + +:Maintainer: Ben Bangert +:Status: Production + +A Zookeeper pool of party members. The :class:`Party` object can be +used for determining members of a party. + +""" +import uuid + +from kazoo.exceptions import NodeExistsError, NoNodeError + + +class BaseParty(object): + """Base implementation of a party.""" + def __init__(self, client, path, identifier=None): + """ + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The party path to use. + :param identifier: An identifier to use for this member of the + party when participating. 
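+
+        Example (a minimal sketch, assuming an already-started
+        :class:`~kazoo.client.KazooClient` named ``client``):
+
+        .. code-block:: python
+
+            party = client.Party('/my/party', identifier='worker-1')
+            party.join()
+            print(len(party))   # number of current participants
+            party.leave()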
+ + """ + self.client = client + self.path = path + self.data = str(identifier or "").encode('utf-8') + self.ensured_path = False + self.participating = False + + def _ensure_parent(self): + if not self.ensured_path: + # make sure our parent node exists + self.client.ensure_path(self.path) + self.ensured_path = True + + def join(self): + """Join the party""" + return self.client.retry(self._inner_join) + + def _inner_join(self): + self._ensure_parent() + try: + self.client.create(self.create_path, self.data, ephemeral=True) + self.participating = True + except NodeExistsError: + # node was already created, perhaps we are recovering from a + # suspended connection + self.participating = True + + def leave(self): + """Leave the party""" + self.participating = False + return self.client.retry(self._inner_leave) + + def _inner_leave(self): + try: + self.client.delete(self.create_path) + except NoNodeError: + return False + return True + + def __len__(self): + """Return a count of participating clients""" + self._ensure_parent() + return len(self._get_children()) + + def _get_children(self): + return self.client.retry(self.client.get_children, self.path) + + +class Party(BaseParty): + """Simple pool of participating processes""" + _NODE_NAME = "__party__" + + def __init__(self, client, path, identifier=None): + BaseParty.__init__(self, client, path, identifier=identifier) + self.node = uuid.uuid4().hex + self._NODE_NAME + self.create_path = self.path + "/" + self.node + + def __iter__(self): + """Get a list of participating clients' data values""" + self._ensure_parent() + children = self._get_children() + for child in children: + try: + d, _ = self.client.retry(self.client.get, self.path + + "/" + child) + yield d.decode('utf-8') + except NoNodeError: # pragma: nocover + pass + + def _get_children(self): + children = BaseParty._get_children(self) + return [c for c in children if self._NODE_NAME in c] + + +class ShallowParty(BaseParty): + """Simple shallow pool of participating processes + + This differs from the :class:`Party` as the identifier is used in + the name of the party node itself, rather than the data. This + places some restrictions on the length as it must be a valid + Zookeeper node (an alphanumeric string), but reduces the overhead + of getting a list of participants to a single Zookeeper call. + + """ + def __init__(self, client, path, identifier=None): + BaseParty.__init__(self, client, path, identifier=identifier) + self.node = '-'.join([uuid.uuid4().hex, self.data.decode('utf-8')]) + self.create_path = self.path + "/" + self.node + + def __iter__(self): + """Get a list of participating clients' identifiers""" + self._ensure_parent() + children = self._get_children() + for child in children: + yield child[child.find('-') + 1:] diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/queue.py b/yarn/src/main/python/task-starter/kazoo/recipe/queue.py new file mode 100644 index 00000000..e57873c6 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/queue.py @@ -0,0 +1,330 @@ +"""Zookeeper based queue implementations. + +:Maintainer: None +:Status: Possibly Buggy + +.. note:: + + This queue was reported to cause memory leaks over long running periods. 
+ See: https://github.com/python-zk/kazoo/issues/175 + +""" + +import uuid +from kazoo.exceptions import NoNodeError, NodeExistsError +from kazoo.retry import ForceRetryError +from kazoo.protocol.states import EventType + + +class BaseQueue(object): + """A common base class for queue implementations.""" + + def __init__(self, client, path): + """ + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The queue path to use in ZooKeeper. + """ + self.client = client + self.path = path + self._entries_path = path + self.structure_paths = (self.path, ) + self.ensured_path = False + + def _check_put_arguments(self, value, priority=100): + if not isinstance(value, bytes): + raise TypeError("value must be a byte string") + if not isinstance(priority, int): + raise TypeError("priority must be an int") + elif priority < 0 or priority > 999: + raise ValueError("priority must be between 0 and 999") + + def _ensure_paths(self): + if not self.ensured_path: + # make sure our parent / internal structure nodes exists + for path in self.structure_paths: + self.client.ensure_path(path) + self.ensured_path = True + + def __len__(self): + self._ensure_paths() + _, stat = self.client.retry(self.client.get, self._entries_path) + return stat.children_count + + +class Queue(BaseQueue): + """A distributed queue with optional priority support. + + This queue does not offer reliable consumption. An entry is removed + from the queue prior to being processed. So if an error occurs, the + consumer has to re-queue the item or it will be lost. + + """ + + prefix = "entry-" + + def __init__(self, client, path): + """ + :param client: A :class:`~kazoo.client.KazooClient` instance. + :param path: The queue path to use in ZooKeeper. + """ + super(Queue, self).__init__(client, path) + self._children = [] + + def __len__(self): + """Return queue size.""" + return super(Queue, self).__len__() + + def get(self): + """ + Get item data and remove an item from the queue. + + :returns: Item data or None. + :rtype: bytes + """ + self._ensure_paths() + return self.client.retry(self._inner_get) + + def _inner_get(self): + if not self._children: + self._children = self.client.retry( + self.client.get_children, self.path) + self._children = sorted(self._children) + if not self._children: + return None + name = self._children[0] + try: + data, stat = self.client.get(self.path + "/" + name) + except NoNodeError: # pragma: nocover + # the first node has vanished in the meantime, try to + # get another one + raise ForceRetryError() + try: + self.client.delete(self.path + "/" + name) + except NoNodeError: # pragma: nocover + # we were able to get the data but someone else has removed + # the node in the meantime. consider the item as processed + # by the other process + raise ForceRetryError() + self._children.pop(0) + return data + + def put(self, value, priority=100): + """Put an item into the queue. + + :param value: Byte string to put into the queue. + :param priority: + An optional priority as an integer with at most 3 digits. + Lower values signify higher priority. + """ + self._check_put_arguments(value, priority) + self._ensure_paths() + path = '{path}/{prefix}{priority:03d}-'.format( + path=self.path, prefix=self.prefix, priority=priority) + self.client.create(path, value, sequence=True) + + +class LockingQueue(BaseQueue): + """A distributed queue with priority and locking support. + + Upon retrieving an entry from the queue, the entry gets locked with an + ephemeral node (instead of deleted). 
+    If an error occurs, this lock gets released so that others can
+    retake the entry. This adds a little penalty compared to the
+    :class:`Queue` implementation.
+
+    The user should call the :meth:`LockingQueue.get` method first to lock and
+    retrieve the next entry. When finished processing the entry, a user should
+    call the :meth:`LockingQueue.consume` method that will remove the entry
+    from the queue.
+
+    This queue will not track connection status with ZooKeeper. If a node locks
+    an element, then loses connection with ZooKeeper and later reconnects, the
+    lock will probably be removed by ZooKeeper in the meantime, but the node
+    would still think that it holds the lock. The user should check the
+    connection status with ZooKeeper, or call the
+    :meth:`LockingQueue.holds_lock` method to check whether the node still
+    holds the lock.
+
+    .. note::
+        :class:`LockingQueue` requires ZooKeeper 3.4 or above, since it is
+        using transactions.
+    """
+    lock = "/taken"
+    entries = "/entries"
+    entry = "entry"
+
+    def __init__(self, client, path):
+        """
+        :param client: A :class:`~kazoo.client.KazooClient` instance.
+        :param path: The queue path to use in ZooKeeper.
+        """
+        super(LockingQueue, self).__init__(client, path)
+        self.id = uuid.uuid4().hex.encode()
+        self.processing_element = None
+        self._lock_path = self.path + self.lock
+        self._entries_path = self.path + self.entries
+        self.structure_paths = (self._lock_path, self._entries_path)
+
+    def __len__(self):
+        """Returns the current length of the queue.
+
+        :returns: queue size (includes locked entries count).
+        """
+        return super(LockingQueue, self).__len__()
+
+    def put(self, value, priority=100):
+        """Put an entry into the queue.
+
+        :param value: Byte string to put into the queue.
+        :param priority:
+            An optional priority as an integer with at most 3 digits.
+            Lower values signify higher priority.
+
+        """
+        self._check_put_arguments(value, priority)
+        self._ensure_paths()
+
+        self.client.create(
+            "{path}/{prefix}-{priority:03d}-".format(
+                path=self._entries_path,
+                prefix=self.entry,
+                priority=priority),
+            value, sequence=True)
+
+    def put_all(self, values, priority=100):
+        """Put several entries into the queue. The action only succeeds
+        if all entries were put into the queue.
+
+        :param values: A list of values to put into the queue.
+        :param priority:
+            An optional priority as an integer with at most 3 digits.
+            Lower values signify higher priority.
+
+        """
+        if not isinstance(values, list):
+            raise TypeError("values must be a list of byte strings")
+        if not isinstance(priority, int):
+            raise TypeError("priority must be an int")
+        elif priority < 0 or priority > 999:
+            raise ValueError("priority must be between 0 and 999")
+        self._ensure_paths()
+
+        with self.client.transaction() as transaction:
+            for value in values:
+                if not isinstance(value, bytes):
+                    raise TypeError("value must be a byte string")
+                transaction.create(
+                    "{path}/{prefix}-{priority:03d}-".format(
+                        path=self._entries_path,
+                        prefix=self.entry,
+                        priority=priority),
+                    value, sequence=True)
+
+    def get(self, timeout=None):
+        """Locks and gets an entry from the queue. If a previously
+        retrieved entry was not consumed, this method will return that entry.
+
+        :param timeout:
+            Maximum waiting time in seconds. If None then it will wait
+            until an entry appears in the queue.
+        :returns: A locked entry value or None if the timeout was reached.
+ :rtype: bytes + """ + self._ensure_paths() + if self.processing_element is not None: + return self.processing_element[1] + else: + return self._inner_get(timeout) + + def holds_lock(self): + """Checks if a node still holds the lock. + + :returns: True if a node still holds the lock, False otherwise. + :rtype: bool + """ + if self.processing_element is None: + return False + lock_id, _ = self.processing_element + lock_path = "{path}/{id}".format(path=self._lock_path, id=lock_id) + self.client.sync(lock_path) + value, stat = self.client.retry(self.client.get, lock_path) + return value == self.id + + def consume(self): + """Removes a currently processing entry from the queue. + + :returns: True if element was removed successfully, False otherwise. + :rtype: bool + """ + if self.processing_element is not None and self.holds_lock(): + id_, value = self.processing_element + with self.client.transaction() as transaction: + transaction.delete("{path}/{id}".format( + path=self._entries_path, + id=id_)) + transaction.delete("{path}/{id}".format( + path=self._lock_path, + id=id_)) + self.processing_element = None + return True + else: + return False + + def _inner_get(self, timeout): + flag = self.client.handler.event_object() + lock = self.client.handler.lock_object() + canceled = False + value = [] + + def check_for_updates(event): + if event is not None and event.type != EventType.CHILD: + return + with lock: + if canceled or flag.isSet(): + return + values = self.client.retry( + self.client.get_children, + self._entries_path, + check_for_updates) + taken = self.client.retry( + self.client.get_children, + self._lock_path, + check_for_updates) + available = self._filter_locked(values, taken) + if len(available) > 0: + ret = self._take(available[0]) + if ret is not None: + # By this time, no one took the task + value.append(ret) + flag.set() + + check_for_updates(None) + retVal = None + flag.wait(timeout) + with lock: + canceled = True + if len(value) > 0: + # We successfully locked an entry + self.processing_element = value[0] + retVal = value[0][1] + return retVal + + def _filter_locked(self, values, taken): + taken = set(taken) + available = sorted(values) + return (available if len(taken) == 0 else + [x for x in available if x not in taken]) + + def _take(self, id_): + try: + self.client.create( + "{path}/{id}".format( + path=self._lock_path, + id=id_), + self.id, + ephemeral=True) + value, stat = self.client.retry( + self.client.get, + "{path}/{id}".format(path=self._entries_path, id=id_)) + except (NoNodeError, NodeExistsError): + # Item is already consumed or locked + return None + return (id_, value) diff --git a/yarn/src/main/python/task-starter/kazoo/recipe/watchers.py b/yarn/src/main/python/task-starter/kazoo/recipe/watchers.py new file mode 100644 index 00000000..ad585da1 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/recipe/watchers.py @@ -0,0 +1,419 @@ +"""Higher level child and data watching API's. + +:Maintainer: Ben Bangert +:Status: Production + +.. note:: + + :ref:`DataWatch` and :ref:`ChildrenWatch` may only handle a single + function, attempts to associate a single instance with multiple functions + will result in an exception being thrown. 
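+
+A watch can stop itself by returning ``False`` from the registered
+function; a minimal sketch, assuming a started
+:class:`~kazoo.client.KazooClient` named ``client``::
+
+    @client.DataWatch('/some/node')
+    def watch_node(data, stat):
+        if data is None:
+            # The node is gone (or not yet created); stop watching
+            return False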
+ +""" +import logging +import time +import warnings +from functools import partial, wraps + +from kazoo.retry import KazooRetry +from kazoo.exceptions import ( + ConnectionClosedError, + NoNodeError, + KazooException +) +from kazoo.protocol.states import KazooState + +log = logging.getLogger(__name__) + + +_STOP_WATCHING = object() + + +def _ignore_closed(func): + @wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except ConnectionClosedError: + pass + return wrapper + + +class DataWatch(object): + """Watches a node for data updates and calls the specified + function each time it changes + + The function will also be called the very first time its + registered to get the data. + + Returning `False` from the registered function will disable future + data change calls. If the client connection is closed (using the + close command), the DataWatch will no longer get updates. + + If the function supplied takes three arguments, then the third one + will be a :class:`~kazoo.protocol.states.WatchedEvent`. It will + only be set if the change to the data occurs as a result of the + server notifying the watch that there has been a change. Events + like reconnection or the first call will not include an event. + + If the node does not exist, then the function will be called with + ``None`` for all values. + + .. tip:: + + Because :class:`DataWatch` can watch nodes that don't exist, it + can be used alternatively as a higher-level Exists watcher that + survives reconnections and session loss. + + Example with client: + + .. code-block:: python + + @client.DataWatch('/path/to/watch') + def my_func(data, stat): + print("Data is %s" % data) + print("Version is %s" % stat.version) + + # Above function is called immediately and prints + + # Or if you want the event object + @client.DataWatch('/path/to/watch') + def my_func(data, stat, event): + print("Data is %s" % data) + print("Version is %s" % stat.version) + print("Event is %s" % event) + + .. versionchanged:: 1.2 + + DataWatch now ignores additional arguments that were previously + passed to it and warns that they are no longer respected. + + """ + def __init__(self, client, path, func=None, *args, **kwargs): + """Create a data watcher for a path + + :param client: A zookeeper client. + :type client: :class:`~kazoo.client.KazooClient` + :param path: The path to watch for data changes on. + :type path: str + :param func: Function to call initially and every time the + node changes. `func` will be called with a + tuple, the value of the node and a + :class:`~kazoo.client.ZnodeStat` instance. + :type func: callable + + """ + self._client = client + self._path = path + self._func = func + self._stopped = False + self._run_lock = client.handler.lock_object() + self._version = None + self._retry = KazooRetry(max_tries=None, + sleep_func=client.handler.sleep_func) + self._include_event = None + self._ever_called = False + self._used = False + + if args or kwargs: + warnings.warn('Passing additional arguments to DataWatch is' + ' deprecated. 
ignore_missing_node is now assumed '
+                          ' to be True by default, and the event will be '
+                          ' sent if the function can handle receiving it',
+                          DeprecationWarning, stacklevel=2)
+
+        # Register our session listener if we're going to resume
+        # across session losses
+        if func is not None:
+            self._used = True
+            self._client.add_listener(self._session_watcher)
+            self._get_data()
+
+    def __call__(self, func):
+        """Callable version for use as a decorator
+
+        :param func: Function to call initially and every time the
+                     data changes. `func` will be called with a
+                     tuple, the value of the node and a
+                     :class:`~kazoo.client.ZnodeStat` instance.
+        :type func: callable
+
+        """
+        if self._used:
+            raise KazooException(
+                "A function has already been associated with this "
+                "DataWatch instance.")
+
+        self._func = func
+
+        self._used = True
+        self._client.add_listener(self._session_watcher)
+        self._get_data()
+        return func
+
+    def _log_func_exception(self, data, stat, event=None):
+        try:
+            # For backwards compatibility, don't send event to the
+            # callback unless the send_event is set in constructor
+            if not self._ever_called:
+                self._ever_called = True
+            try:
+                result = self._func(data, stat, event)
+            except TypeError:
+                result = self._func(data, stat)
+            if result is False:
+                self._stopped = True
+                self._client.remove_listener(self._session_watcher)
+        except Exception as exc:
+            log.exception(exc)
+            raise
+
+    @_ignore_closed
+    def _get_data(self, event=None):
+        # Ensure this runs one at a time, possible because the session
+        # watcher may trigger a run
+        with self._run_lock:
+            if self._stopped:
+                return
+
+            initial_version = self._version
+
+            try:
+                data, stat = self._retry(self._client.get,
+                                         self._path, self._watcher)
+            except NoNodeError:
+                data = None
+
+                # This will set 'stat' to None if the node does not yet
+                # exist.
+                stat = self._retry(self._client.exists, self._path,
+                                   self._watcher)
+                if stat:
+                    self._client.handler.spawn(self._get_data)
+                    return
+
+            # No node data, clear out version
+            if stat is None:
+                self._version = None
+            else:
+                self._version = stat.mzxid
+
+            # Call our function if it's the first time ever, or if the
+            # version has changed
+            if initial_version != self._version or not self._ever_called:
+                self._log_func_exception(data, stat, event)
+
+    def _watcher(self, event):
+        self._get_data(event=event)
+
+    def _set_watch(self, state):
+        with self._run_lock:
+            self._watch_established = state
+
+    def _session_watcher(self, state):
+        if state == KazooState.CONNECTED:
+            self._client.handler.spawn(self._get_data)
+
+
+class ChildrenWatch(object):
+    """Watches a node for children updates and calls the specified
+    function each time it changes
+
+    The function will also be called the very first time it's
+    registered to get children.
+
+    Returning `False` from the registered function will disable future
+    children change calls. If the client connection is closed (using
+    the close command), the ChildrenWatch will no longer get updates.
+
+    If ``send_event=True`` in ``__init__``, then the function will always be
+    called with second parameter, ``event``. Upon initial call or when
+    recovering a lost session the ``event`` is always ``None``.
+    Otherwise it's a :class:`~kazoo.protocol.states.WatchedEvent`
+    instance.
+
+    Example with client:
+
+    .. code-block:: python
+
+        @client.ChildrenWatch('/path/to/watch')
+        def my_func(children):
+            print("Children are %s" % children)
+
+        # Above function is called immediately and prints children
+
+    """
+    def __init__(self, client, path, func=None,
+                 allow_session_lost=True, send_event=False):
+        """Create a children watcher for a path
+
+        :param client: A zookeeper client.
+        :type client: :class:`~kazoo.client.KazooClient`
+        :param path: The path to watch for children on.
+        :type path: str
+        :param func: Function to call initially and every time the
+                     children change. `func` will be called with a
+                     single argument, the list of children.
+        :type func: callable
+        :param allow_session_lost: Whether the watch should be
+                                   re-registered if the zookeeper
+                                   session is lost.
+        :type allow_session_lost: bool
+        :type send_event: bool
+        :param send_event: Whether the function should be passed the
+                           event sent by ZooKeeper or None upon
+                           initialization (see class documentation)
+
+        The path must already exist for the children watcher to
+        run.
+
+        """
+        self._client = client
+        self._path = path
+        self._func = func
+        self._send_event = send_event
+        self._stopped = False
+        self._watch_established = False
+        self._allow_session_lost = allow_session_lost
+        self._run_lock = client.handler.lock_object()
+        self._prior_children = None
+        self._used = False
+
+        # Register our session listener if we're going to resume
+        # across session losses
+        if func is not None:
+            self._used = True
+            if allow_session_lost:
+                self._client.add_listener(self._session_watcher)
+            self._get_children()
+
+    def __call__(self, func):
+        """Callable version for use as a decorator
+
+        :param func: Function to call initially and every time the
+                     children change. `func` will be called with a
+                     single argument, the list of children.
+        :type func: callable
+
+        """
+        if self._used:
+            raise KazooException(
+                "A function has already been associated with this "
+                "ChildrenWatch instance.")
+
+        self._func = func
+
+        self._used = True
+        if self._allow_session_lost:
+            self._client.add_listener(self._session_watcher)
+        self._get_children()
+        return func
+
+    @_ignore_closed
+    def _get_children(self, event=None):
+        with self._run_lock:  # Ensure this runs one at a time
+            if self._stopped:
+                return
+
+            children = self._client.retry(self._client.get_children,
+                                          self._path, self._watcher)
+            if not self._watch_established:
+                self._watch_established = True
+
+                if self._prior_children is not None and \
+                        self._prior_children == children:
+                    return
+
+            self._prior_children = children
+
+            try:
+                if self._send_event:
+                    result = self._func(children, event)
+                else:
+                    result = self._func(children)
+                if result is False:
+                    self._stopped = True
+            except Exception as exc:
+                log.exception(exc)
+                raise
+
+    def _watcher(self, event):
+        self._get_children(event)
+
+    def _session_watcher(self, state):
+        if state in (KazooState.LOST, KazooState.SUSPENDED):
+            self._watch_established = False
+        elif (state == KazooState.CONNECTED and
+              not self._watch_established and not self._stopped):
+            self._client.handler.spawn(self._get_children)
+
+
+class PatientChildrenWatch(object):
+    """Patient Children Watch that returns values after the children
+    of a node don't change for a period of time
+
+    A separate watcher for the children of a node, that ignores
+    changes within a boundary time and sets the result only when the
+    boundary time has elapsed with no children changes.
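+
+    Any child change observed during the boundary period restarts the
+    countdown.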
+ + Example:: + + watcher = PatientChildrenWatch(client, '/some/path', + time_boundary=5) + async_object = watcher.start() + + # Blocks until the children have not changed for time boundary + # (5 in this case) seconds, returns children list and an + # async_result that will be set if the children change in the + # future + children, child_async = async_object.get() + + .. note:: + + This Watch is different from :class:`DataWatch` and + :class:`ChildrenWatch` as it only returns once, does not take + a function that is called, and provides an + :class:`~kazoo.interfaces.IAsyncResult` object that can be + checked to see if the children have changed later. + + """ + def __init__(self, client, path, time_boundary=30): + self.client = client + self.path = path + self.children = [] + self.time_boundary = time_boundary + self.children_changed = client.handler.event_object() + + def start(self): + """Begin the watching process asynchronously + + :returns: An :class:`~kazoo.interfaces.IAsyncResult` instance + that will be set when no change has occurred to the + children for time boundary seconds. + + """ + self.asy = asy = self.client.handler.async_result() + self.client.handler.spawn(self._inner_start) + return asy + + def _inner_start(self): + try: + while True: + async_result = self.client.handler.async_result() + self.children = self.client.retry( + self.client.get_children, self.path, + partial(self._children_watcher, async_result)) + self.client.handler.sleep_func(self.time_boundary) + + if self.children_changed.is_set(): + self.children_changed.clear() + else: + break + + self.asy.set((self.children, async_result)) + except Exception as exc: + self.asy.set_exception(exc) + + def _children_watcher(self, async, event): + self.children_changed.set() + async.set(time.time()) diff --git a/yarn/src/main/python/task-starter/kazoo/retry.py b/yarn/src/main/python/task-starter/kazoo/retry.py new file mode 100644 index 00000000..4926882a --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/retry.py @@ -0,0 +1,153 @@ +import logging +import random +import time + +from kazoo.exceptions import ( + ConnectionClosedError, + ConnectionLoss, + KazooException, + OperationTimeoutError, + SessionExpiredError, +) + +log = logging.getLogger(__name__) + + +class ForceRetryError(Exception): + """Raised when some recipe logic wants to force a retry.""" + + +class RetryFailedError(KazooException): + """Raised when retrying an operation ultimately failed, after + retrying the maximum number of attempts. + """ + + +class InterruptedError(RetryFailedError): + """Raised when the retry is forcibly interrupted by the interrupt + function""" + + +class KazooRetry(object): + """Helper for retrying a method in the face of retry-able + exceptions""" + RETRY_EXCEPTIONS = ( + ConnectionLoss, + OperationTimeoutError, + ForceRetryError + ) + + EXPIRED_EXCEPTIONS = ( + SessionExpiredError, + ) + + def __init__(self, max_tries=1, delay=0.1, backoff=2, max_jitter=0.8, + max_delay=3600, ignore_expire=True, sleep_func=time.sleep, + deadline=None, interrupt=None): + """Create a :class:`KazooRetry` instance for retrying function + calls + + :param max_tries: How many times to retry the command. -1 means + infinite tries. + :param delay: Initial delay between retry attempts. + :param backoff: Backoff multiplier between retry attempts. + Defaults to 2 for exponential backoff. + :param max_jitter: Additional max jitter period to wait between + retry attempts to avoid slamming the server. 
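+                           (A uniform random value of up to ``max_jitter``
+                           seconds, in 0.01 second steps, is added to each
+                           delay.)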
+        :param max_delay: Maximum delay in seconds, regardless of other
+                          backoff settings. Defaults to one hour.
+        :param ignore_expire:
+            Whether a session expiration should be ignored and treated
+            as a retry-able command.
+        :param interrupt:
+            Function that will be called with no args that may return
+            True if the retry should be ceased immediately. This will
+            be called no more than every 0.1 seconds during a wait
+            between retries.
+
+        """
+        self.max_tries = max_tries
+        self.delay = delay
+        self.backoff = backoff
+        self.max_jitter = int(max_jitter * 100)
+        self.max_delay = float(max_delay)
+        self._attempts = 0
+        self._cur_delay = delay
+        self.deadline = deadline
+        self._cur_stoptime = None
+        self.sleep_func = sleep_func
+        self.retry_exceptions = self.RETRY_EXCEPTIONS
+        self.interrupt = interrupt
+        if ignore_expire:
+            self.retry_exceptions += self.EXPIRED_EXCEPTIONS
+
+    def reset(self):
+        """Reset the attempt counter"""
+        self._attempts = 0
+        self._cur_delay = self.delay
+        self._cur_stoptime = None
+
+    def copy(self):
+        """Return a clone of this retry manager"""
+        obj = KazooRetry(max_tries=self.max_tries,
+                         delay=self.delay,
+                         backoff=self.backoff,
+                         max_jitter=self.max_jitter / 100.0,
+                         max_delay=self.max_delay,
+                         sleep_func=self.sleep_func,
+                         deadline=self.deadline,
+                         interrupt=self.interrupt)
+        obj.retry_exceptions = self.retry_exceptions
+        return obj
+
+    def __call__(self, func, *args, **kwargs):
+        """Call a function with arguments until it completes without
+        throwing a Kazoo exception
+
+        :param func: Function to call
+        :param args: Positional arguments to call the function with
+        :param kwargs: Keyword arguments to call the function with
+
+        The function will be called until it doesn't throw one of the
+        retryable exceptions (ConnectionLoss, OperationTimeout, or
+        ForceRetryError), optionally retrying on session expiration
+        as well.
+
+        """
+        self.reset()
+
+        while True:
+            try:
+                if self.deadline is not None and self._cur_stoptime is None:
+                    self._cur_stoptime = time.time() + self.deadline
+                return func(*args, **kwargs)
+            except ConnectionClosedError:
+                raise
+            except self.retry_exceptions:
+                # Note: max_tries == -1 means infinite tries.
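+                # Each retry sleeps for the current delay plus up to
+                # max_jitter seconds of random jitter, honoring the
+                # deadline and the optional interrupt callback below.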
+ if self._attempts == self.max_tries: + raise RetryFailedError("Too many retry attempts") + self._attempts += 1 + sleeptime = self._cur_delay + ( + random.randint(0, self.max_jitter) / 100.0) + + if self._cur_stoptime is not None and \ + time.time() + sleeptime >= self._cur_stoptime: + raise RetryFailedError("Exceeded retry deadline") + + if self.interrupt: + while sleeptime > 0: + # Break the time period down and sleep for no + # longer than 0.1 before calling the interrupt + if sleeptime < 0.1: + self.sleep_func(sleeptime) + sleeptime -= sleeptime + else: + self.sleep_func(0.1) + sleeptime -= 0.1 + if self.interrupt(): + raise InterruptedError() + else: + self.sleep_func(sleeptime) + self._cur_delay = min(self._cur_delay * self.backoff, + self.max_delay) diff --git a/yarn/src/main/python/task-starter/kazoo/security.py b/yarn/src/main/python/task-starter/kazoo/security.py new file mode 100644 index 00000000..014646d0 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/security.py @@ -0,0 +1,138 @@ +"""Kazoo Security""" +from base64 import b64encode +from collections import namedtuple +import hashlib + + +# Represents a Zookeeper ID and ACL object +Id = namedtuple('Id', 'scheme id') + + +class ACL(namedtuple('ACL', 'perms id')): + """An ACL for a Zookeeper Node + + An ACL object is created by using an :class:`Id` object along with + a :class:`Permissions` setting. For convenience, + :meth:`make_digest_acl` should be used to create an ACL object with + the desired scheme, id, and permissions. + + """ + @property + def acl_list(self): + perms = [] + if self.perms & Permissions.ALL == Permissions.ALL: + perms.append('ALL') + return perms + if self.perms & Permissions.READ == Permissions.READ: + perms.append('READ') + if self.perms & Permissions.WRITE == Permissions.WRITE: + perms.append('WRITE') + if self.perms & Permissions.CREATE == Permissions.CREATE: + perms.append('CREATE') + if self.perms & Permissions.DELETE == Permissions.DELETE: + perms.append('DELETE') + if self.perms & Permissions.ADMIN == Permissions.ADMIN: + perms.append('ADMIN') + return perms + + def __repr__(self): + return 'ACL(perms=%r, acl_list=%s, id=%r)' % ( + self.perms, self.acl_list, self.id) + + +class Permissions(object): + READ = 1 + WRITE = 2 + CREATE = 4 + DELETE = 8 + ADMIN = 16 + ALL = 31 + + +# Shortcuts for common Ids +ANYONE_ID_UNSAFE = Id('world', 'anyone') +AUTH_IDS = Id('auth', '') + +# Shortcuts for common ACLs +OPEN_ACL_UNSAFE = [ACL(Permissions.ALL, ANYONE_ID_UNSAFE)] +CREATOR_ALL_ACL = [ACL(Permissions.ALL, AUTH_IDS)] +READ_ACL_UNSAFE = [ACL(Permissions.READ, ANYONE_ID_UNSAFE)] + + +def make_digest_acl_credential(username, password): + """Create a SHA1 digest credential""" + credential = username.encode('utf-8') + b":" + password.encode('utf-8') + cred_hash = b64encode(hashlib.sha1(credential).digest()).strip() + return username + ":" + cred_hash.decode('utf-8') + + +def make_acl(scheme, credential, read=False, write=False, + create=False, delete=False, admin=False, all=False): + """Given a scheme and credential, return an :class:`ACL` object + appropriate for use with Kazoo. + + :param scheme: The scheme to use. I.e. `digest`. + :param credential: + A colon separated username, password. The password should be + hashed with the `scheme` specified. The + :meth:`make_digest_acl_credential` method will create and + return a credential appropriate for use with the `digest` + scheme. + :param write: Write permission. + :type write: bool + :param create: Create permission. 
+ :type create: bool + :param delete: Delete permission. + :type delete: bool + :param admin: Admin permission. + :type admin: bool + :param all: All permissions. + :type all: bool + + :rtype: :class:`ACL` + + """ + if all: + permissions = Permissions.ALL + else: + permissions = 0 + if read: + permissions |= Permissions.READ + if write: + permissions |= Permissions.WRITE + if create: + permissions |= Permissions.CREATE + if delete: + permissions |= Permissions.DELETE + if admin: + permissions |= Permissions.ADMIN + return ACL(permissions, Id(scheme, credential)) + + +def make_digest_acl(username, password, read=False, write=False, + create=False, delete=False, admin=False, all=False): + """Create a digest ACL for Zookeeper with the given permissions + + This method combines :meth:`make_digest_acl_credential` and + :meth:`make_acl` to create an :class:`ACL` object appropriate for + use with Kazoo's ACL methods. + + :param username: Username to use for the ACL. + :param password: A plain-text password to hash. + :param write: Write permission. + :type write: bool + :param create: Create permission. + :type create: bool + :param delete: Delete permission. + :type delete: bool + :param admin: Admin permission. + :type admin: bool + :param all: All permissions. + :type all: bool + + :rtype: :class:`ACL` + + """ + cred = make_digest_acl_credential(username, password) + return make_acl("digest", cred, read=read, write=write, create=create, + delete=delete, admin=admin, all=all) diff --git a/yarn/src/main/python/task-starter/kazoo/testing/__init__.py b/yarn/src/main/python/task-starter/kazoo/testing/__init__.py new file mode 100644 index 00000000..bf8f149a --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/testing/__init__.py @@ -0,0 +1,5 @@ +from kazoo.testing.harness import KazooTestCase +from kazoo.testing.harness import KazooTestHarness + + +__all__ = ('KazooTestHarness', 'KazooTestCase', ) diff --git a/yarn/src/main/python/task-starter/kazoo/testing/common.py b/yarn/src/main/python/task-starter/kazoo/testing/common.py new file mode 100644 index 00000000..e0a34036 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/testing/common.py @@ -0,0 +1,308 @@ +# +# Copyright (C) 2010-2011, 2011 Canonical Ltd. All Rights Reserved +# +# This file was originally taken from txzookeeper and modified later. +# +# Authors: +# Kapil Thangavelu and the Kazoo team +# +# txzookeeper is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# txzookeeper is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with txzookeeper. If not, see . + + +import code +import logging +import os +import os.path +import shutil +import signal +import subprocess +import tempfile +import traceback + +from itertools import chain +from collections import namedtuple +from glob import glob + + +log = logging.getLogger(__name__) + + +def debug(sig, frame): + """Interrupt running process, and provide a python prompt for + interactive debugging.""" + d = {'_frame': frame} # Allow access to frame object. 
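+    # Build the interactive namespace: globals first, then locals,
+    # so locals take precedence on any name clash.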
+    d.update(frame.f_globals)  # Unless shadowed by global
+    d.update(frame.f_locals)
+
+    i = code.InteractiveConsole(d)
+    message = "Signal received : entering python shell.\nTraceback:\n"
+    message += ''.join(traceback.format_stack(frame))
+    i.interact(message)
+
+
+def listen():
+    if os.name != 'nt':  # SIGUSR1 is not supported on Windows
+        signal.signal(signal.SIGUSR1, debug)  # Register handler
+listen()
+
+
+def to_java_compatible_path(path):
+    if os.name == 'nt':
+        path = path.replace('\\', '/')
+    return path
+
+ServerInfo = namedtuple(
+    "ServerInfo",
+    "server_id client_port election_port leader_port admin_port")
+
+
+class ManagedZooKeeper(object):
+    """Class to manage the running of a ZooKeeper instance for testing.
+
+    Note: no attempt is made to probe whether the ZooKeeper instance is
+    actually available, or whether the selected port is free. In the
+    future, we may want to do that, especially when run in a
+    Hudson/Buildbot context, to ensure more test robustness."""
+
+    def __init__(self, software_path, server_info, peers=(), classpath=None):
+        """Define the ZooKeeper test instance.
+
+        @param software_path: The path to the ZooKeeper install to run
+        @param server_info: The ServerInfo describing the ports and id
+                            of the managed ZK instance
+        """
+        self.install_path = software_path
+        self._classpath = classpath
+        self.server_info = server_info
+        self.host = "127.0.0.1"
+        self.peers = peers
+        self.working_path = tempfile.mkdtemp()
+        self._running = False
+
+    def run(self):
+        """Run the ZooKeeper instance under a temporary directory.
+
+        Writes ZK log messages to zookeeper.log in the current directory.
+        """
+        if self.running:
+            return
+        config_path = os.path.join(self.working_path, "zoo.cfg")
+        log_path = os.path.join(self.working_path, "log")
+        log4j_path = os.path.join(self.working_path, "log4j.properties")
+        data_path = os.path.join(self.working_path, "data")
+
+        # various setup steps
+        if not os.path.exists(self.working_path):
+            os.mkdir(self.working_path)
+        if not os.path.exists(log_path):
+            os.mkdir(log_path)
+        if not os.path.exists(data_path):
+            os.mkdir(data_path)
+
+        with open(config_path, "w") as config:
+            config.write("""
+tickTime=2000
+dataDir=%s
+clientPort=%s
+maxClientCnxns=0
+admin.serverPort=%s
+""" % (to_java_compatible_path(data_path),
+       self.server_info.client_port,
+       self.server_info.admin_port))  # NOQA
+
+        # setup a replicated setup if peers are specified
+        if self.peers:
+            servers_cfg = []
+            for p in chain((self.server_info,), self.peers):
+                servers_cfg.append("server.%s=localhost:%s:%s" % (
+                    p.server_id, p.leader_port, p.election_port))
+
+            with open(config_path, "a") as config:
+                config.write("""
+initLimit=4
+syncLimit=2
+%s
+""" % ("\n".join(servers_cfg)))
+
+        # Write server ids into datadir
+        with open(os.path.join(data_path, "myid"), "w") as myid_file:
+            myid_file.write(str(self.server_info.server_id))
+
+        with open(log4j_path, "w") as log4j:
+            log4j.write("""
+# DEFAULT: console appender only
+log4j.rootLogger=INFO, ROLLINGFILE
+log4j.appender.ROLLINGFILE.layout=org.apache.log4j.PatternLayout
+log4j.appender.ROLLINGFILE.layout.ConversionPattern=%d{ISO8601} [myid:%X{myid}] - %-5p [%t:%C{1}@%L] - %m%n
+log4j.appender.ROLLINGFILE=org.apache.log4j.RollingFileAppender
+log4j.appender.ROLLINGFILE.Threshold=DEBUG
+log4j.appender.ROLLINGFILE.File=""" + to_java_compatible_path(  # NOQA
+            self.working_path + os.sep + "zookeeper.log\n"))
+
+        args = [
+            "java",
+            "-cp", self.classpath,
+
+            # "-Dlog4j.debug",
+            "-Dreadonlymode.enabled=true",
+            "-Dzookeeper.log.dir=%s" % log_path,
+            "-Dzookeeper.root.logger=INFO,CONSOLE",
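+            # Point log4j at the properties file written above.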
+ "-Dlog4j.configuration=file:%s" % log4j_path, + + # OS X: Prevent java from appearing in menu bar, process dock + # and from activation of the main workspace on run. + "-Djava.awt.headless=true", + + "org.apache.zookeeper.server.quorum.QuorumPeerMain", + config_path, + ] + self.process = subprocess.Popen(args=args) + log.info("Started zookeeper process %s using args %s", + self.process.pid, args) + self._running = True + + @property + def classpath(self): + """Get the classpath necessary to run ZooKeeper.""" + + if self._classpath: + return self._classpath + + # Two possibilities, as seen in zkEnv.sh: + # Check for a release - top-level zookeeper-*.jar? + jars = glob((os.path.join( + self.install_path, 'zookeeper-*.jar'))) + if jars: + # Release build (`ant package`) + jars.extend(glob(os.path.join( + self.install_path, + "lib/*.jar"))) + # support for different file locations on Debian/Ubuntu + jars.extend(glob(os.path.join( + self.install_path, + "log4j-*.jar"))) + jars.extend(glob(os.path.join( + self.install_path, + "slf4j-api-*.jar"))) + jars.extend(glob(os.path.join( + self.install_path, + "slf4j-log4j-*.jar"))) + else: + # Development build (plain `ant`) + jars = glob((os.path.join( + self.install_path, 'build/zookeeper-*.jar'))) + jars.extend(glob(os.path.join( + self.install_path, + "build/lib/*.jar"))) + + return os.pathsep.join(jars) + + @property + def address(self): + """Get the address of the ZooKeeper instance.""" + return "%s:%s" % (self.host, self.client_port) + + @property + def running(self): + return self._running + + @property + def client_port(self): + return self.server_info.client_port + + def reset(self): + """Stop the zookeeper instance, cleaning out its on disk-data.""" + self.stop() + shutil.rmtree(os.path.join(self.working_path, "data")) + os.mkdir(os.path.join(self.working_path, "data")) + with open(os.path.join(self.working_path, "data", "myid"), "w") as fh: + fh.write(str(self.server_info.server_id)) + + def stop(self): + """Stop the Zookeeper instance, retaining on disk state.""" + if not self.running: + return + self.process.terminate() + self.process.wait() + if self.process.returncode != 0: + log.warn("Zookeeper process %s failed to terminate with" + " non-zero return code (it terminated with %s return" + " code instead)", self.process.pid, + self.process.returncode) + self._running = False + + def destroy(self): + """Stop the ZooKeeper instance and destroy its on disk-state""" + # called by at exit handler, reimport to avoid cleanup race. 
+ import shutil + self.stop() + + shutil.rmtree(self.working_path) + + +class ZookeeperCluster(object): + + def __init__(self, install_path=None, classpath=None, + size=3, port_offset=20000): + self._install_path = install_path + self._classpath = classpath + self._servers = [] + + # Calculate ports and peer group + port = port_offset + peers = [] + + for i in range(size): + info = ServerInfo(i + 1, port, port + 1, port + 2, port + 3) + peers.append(info) + port += 10 + + # Instantiate Managed ZK Servers + for i in range(size): + server_peers = list(peers) + server_info = server_peers.pop(i) + self._servers.append( + ManagedZooKeeper( + self._install_path, server_info, server_peers, + classpath=self._classpath)) + + def __getitem__(self, k): + return self._servers[k] + + def __iter__(self): + return iter(self._servers) + + def start(self): + # Zookeeper client expresses a preference for either lower ports or + # lexicographical ordering of hosts, to ensure that all servers have a + # chance to startup, start them in reverse order. + for server in reversed(list(self)): + server.run() + # Giving the servers a moment to start, decreases the overall time + # required for a client to successfully connect (2s vs. 4s without + # the sleep). + import time + time.sleep(2) + + def stop(self): + for server in self: + server.stop() + self._servers = [] + + def terminate(self): + for server in self: + server.destroy() + + def reset(self): + for server in self: + server.reset() diff --git a/yarn/src/main/python/task-starter/kazoo/testing/harness.py b/yarn/src/main/python/task-starter/kazoo/testing/harness.py new file mode 100644 index 00000000..26fe2d26 --- /dev/null +++ b/yarn/src/main/python/task-starter/kazoo/testing/harness.py @@ -0,0 +1,165 @@ +"""Kazoo testing harnesses""" + +import logging +import os +import uuid +import unittest + +from kazoo import python2atexit as atexit + +from kazoo.client import KazooClient +from kazoo.exceptions import KazooException, NotEmptyError +from kazoo.protocol.states import ( + KazooState +) +from kazoo.testing.common import ZookeeperCluster +from kazoo.protocol.connection import _CONNECTION_DROP, _SESSION_EXPIRED + +log = logging.getLogger(__name__) + +CLUSTER = None + + +def get_global_cluster(): + global CLUSTER + if CLUSTER is None: + ZK_HOME = os.environ.get("ZOOKEEPER_PATH") + ZK_CLASSPATH = os.environ.get("ZOOKEEPER_CLASSPATH") + ZK_PORT_OFFSET = int(os.environ.get("ZOOKEEPER_PORT_OFFSET", 20000)) + + assert ZK_HOME or ZK_CLASSPATH, ( + "Either ZOOKEEPER_PATH or ZOOKEEPER_CLASSPATH environment " + "variable must be defined.\n" + "For deb package installations this is /usr/share/java") + + CLUSTER = ZookeeperCluster( + install_path=ZK_HOME, + classpath=ZK_CLASSPATH, + port_offset=ZK_PORT_OFFSET, + ) + atexit.register(lambda cluster: cluster.terminate(), CLUSTER) + return CLUSTER + + +class KazooTestHarness(unittest.TestCase): + """Harness for testing code that uses Kazoo + + This object can be used directly or as a mixin. It supports starting + and stopping a complete ZooKeeper cluster locally and provides an + API for simulating errors and expiring sessions. 
+ + Example:: + + class MyTestCase(KazooTestHarness): + def setUp(self): + self.setup_zookeeper() + + # additional test setup + + def tearDown(self): + self.teardown_zookeeper() + + def test_something(self): + something_that_needs_a_kazoo_client(self.client) + + def test_something_else(self): + something_that_needs_zk_servers(self.servers) + + """ + + def __init__(self, *args, **kw): + super(KazooTestHarness, self).__init__(*args, **kw) + self.client = None + self._clients = [] + + @property + def cluster(self): + return get_global_cluster() + + @property + def servers(self): + return ",".join([s.address for s in self.cluster]) + + def _get_nonchroot_client(self): + c = KazooClient(self.servers) + self._clients.append(c) + return c + + def _get_client(self, **kwargs): + c = KazooClient(self.hosts, **kwargs) + self._clients.append(c) + return c + + def lose_connection(self, event_factory): + """Force client to lose connection with server""" + self.__break_connection(_CONNECTION_DROP, KazooState.SUSPENDED, event_factory) + + def expire_session(self, event_factory): + """Force ZK to expire a client session""" + self.__break_connection(_SESSION_EXPIRED, KazooState.LOST, event_factory) + + def setup_zookeeper(self, **client_options): + """Create a ZK cluster and chrooted :class:`KazooClient` + + The cluster will only be created on the first invocation and won't be + fully torn down until exit. + """ + do_start = False + for s in self.cluster: + if not s.running: + do_start = True + if do_start: + self.cluster.start() + namespace = "/kazootests" + uuid.uuid4().hex + self.hosts = self.servers + namespace + if 'timeout' not in client_options: + client_options['timeout'] = 0.8 + self.client = self._get_client(**client_options) + self.client.start() + self.client.ensure_path("/") + + def teardown_zookeeper(self): + """Reset and cleanup the zookeeper cluster that was started.""" + while self._clients: + c = self._clients.pop() + try: + c.stop() + except KazooException: + log.exception("Failed stopping client %s", c) + finally: + c.close() + self.client = None + + def __break_connection(self, break_event, expected_state, event_factory): + """Break ZooKeeper connection using the specified event.""" + + lost = event_factory() + safe = event_factory() + + def watch_loss(state): + if state == expected_state: + lost.set() + elif lost.is_set() and state == KazooState.CONNECTED: + safe.set() + return True + + self.client.add_listener(watch_loss) + self.client._call(break_event, None) + + lost.wait(5) + if not lost.isSet(): + raise Exception("Failed to get notified of broken connection.") + + safe.wait(15) + if not safe.isSet(): + raise Exception("Failed to see client reconnect.") + + self.client.retry(self.client.get_async, '/') + + +class KazooTestCase(KazooTestHarness): + def setUp(self): + self.setup_zookeeper() + + def tearDown(self): + self.teardown_zookeeper() diff --git a/yarn/src/main/python/task-starter/wrapper/__init__.py b/yarn/src/main/python/task-starter/wrapper/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/yarn/src/main/python/task-starter/wrapper/__main__.py b/yarn/src/main/python/task-starter/wrapper/__main__.py new file mode 100644 index 00000000..fbf9a7d3 --- /dev/null +++ b/yarn/src/main/python/task-starter/wrapper/__main__.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python +""" +A wrapper that launches a TensorFlow program. This is launched by the +Application master. 
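+
+The wrapper reserves a free port, publishes it to the ApplicationMaster
+through ZooKeeper, waits for the cluster spec to be written back, exports
+the resulting DTF_* variables, and then launches the task program
+(optionally inside a Docker container).
+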
+
+Input (via env vars):
+
+DTF_APPLICATION_ID
+DTF_TASK_PROGRAM
+DTF_TASK_JOB_NAME
+DTF_TASK_INDEX
+DTF_INPUT_PATH
+DTF_OUTPUT_PATH
+DTF_ZK_HOSTS
+DTF_SERVICE_CLASS
+
+JAVA_HOME
+HADOOP_HOME
+CONTAINER_ID
+"""
+from __future__ import print_function
+
+import atexit
+import json
+import logging
+import os
+import signal
+import socket
+import subprocess
+import sys
+from threading import Event
+
+from kazoo.client import KazooClient
+
+DTF_APPLICATION_ID = "DTF_APPLICATION_ID"
+DTF_SERVICE_CLASS = "DTF_SERVICE_CLASS"
+DTF_ZK_HOSTS = "DTF_ZK_HOSTS"
+
+DTF_TASK_PROGRAM = "DTF_TASK_PROGRAM"
+DTF_INPUT_PATH = "DTF_INPUT_PATH"
+DTF_OUTPUT_PATH = "DTF_OUTPUT_PATH"
+DTF_TASK_JOB_NAME = "DTF_TASK_JOB_NAME"
+DTF_TASK_INDEX = "DTF_TASK_INDEX"
+DTF_DOCKER_IMAGE = "DTF_DOCKER_IMAGE"
+
+JAVA_HOME = "JAVA_HOME"
+HADOOP_HOME = "HADOOP_HOME"
+USER = "USER"
+CONTAINER_ID = "CONTAINER_ID"
+
+AM_PATH = "/registry/users/{user}/{service_class}/{app_id}"
+CN_PATH = "/components/{container_id}"
+
+LAUNCH_CMD = """
+source $HADOOP_HOME/libexec/hadoop-config.sh ;
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$JAVA_HOME/jre/lib/amd64/server ;
+export CLASSPATH=$($HADOOP_HDFS_HOME/bin/hdfs classpath --glob) ; {task_program}
+"""
+
+DOCKER_ENGINE_PREFIX = """
+MAPPING="" ;
+if [[ "${{DTF_INPUT_PATH}}" != "" && "${{DTF_INPUT_PATH}}" != hdfs://* ]]; then
+  MAPPING=${{MAPPING}}" -v ${{DTF_INPUT_PATH}}:${{DTF_INPUT_PATH}}"
+fi ;
+if [[ "${{DTF_OUTPUT_PATH}}" != "" && "${{DTF_OUTPUT_PATH}}" != hdfs://* ]]; then
+  MAPPING=${{MAPPING}}" -v ${{DTF_OUTPUT_PATH}}:${{DTF_OUTPUT_PATH}}"
+fi ;
+/usr/bin/docker run --rm -u $(id -u $USER):$(id -g $USER) --net=host --name=${{CONTAINER_ID}} \
+-v /etc/passwd:/etc/passwd -v /etc/group:/etc/group -v ${{LOCAL_DIRS}}:${{LOCAL_DIRS}} -v ${{LOG_DIRS}}:${{LOG_DIRS}} \
+-v ${{HADOOP_HOME}}:${{HADOOP_HOME}} -v ${{HADOOP_CONF_DIR}}:${{HADOOP_CONF_DIR}} \
+-v ${{JAVA_HOME}}:${{JAVA_HOME}} ${{MAPPING}} -e JAVA_HOME -e HADOOP_HOME -e LD_LIBRARY_PATH \
+-e CLASSPATH -e HADOOP_CONF_DIR -e HADOOP_HDFS_HOME -e LOCAL_DIRS -e LOG_DIRS -e CONTAINER_ID \
+{dtf_vars} {docker_image} bash -c 'cd ${{LOCAL_DIRS}}/${{CONTAINER_ID}} ; {task_program}'
+"""
+
+SLEEP_TIME = 1
+LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+LOG = logging.getLogger('wrapper')
+
+def generate_docker_cmd(docker_image, task_program):
+    "Generate the docker run command, forwarding all DTF_* env vars"
+    dtf_vars = ""
+    for key in os.environ:
+        if key.startswith('DTF_'):
+            dtf_vars += " -e %s" % key
+
+    return DOCKER_ENGINE_PREFIX.format(
+        dtf_vars=dtf_vars,
+        docker_image=docker_image,
+        task_program=task_program)
+
+def launch_prog(opts):
+    """
+    Launch the task program, either directly or inside a docker container
+    """
+    task_program = 'source %s' % opts[DTF_TASK_PROGRAM]
+    docker_image = os.getenv(DTF_DOCKER_IMAGE)
+    if docker_image is None:
+        cmd = LAUNCH_CMD.format(task_program=task_program)
+    else:
+        cmd = LAUNCH_CMD.format(
+            task_program=generate_docker_cmd(docker_image, task_program))
+
+    LOG.info("Running %s", cmd)
+    proc = subprocess.Popen(
+        ['bash', '-c', cmd], stdout=sys.stdout.fileno(),
+        stderr=sys.stderr.fileno())
+
+    proc.communicate()
+
+    LOG.info('Task program return code=%s', proc.returncode)
+
+    return proc.returncode
+
+def get_socket(ipaddr='', port=0):
+    "Reserve a port and return the port number and the bound socket"
+    try:
+        sock = socket.socket()
+        sock.bind((ipaddr, port))
+        port = sock.getsockname()[1]
+        return port, sock
+    except socket.error as err:
+        LOG.error("Encountered error while trying to open socket - %s", err)
+        raise
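+
+# Illustrative local invocation (an assumption for experimentation; in a real
+# run the ApplicationMaster sets these variables, and the ids below are
+# placeholders). JAVA_HOME, HADOOP_HOME and USER must also be set; see
+# get_envs() below.
+#
+#   export DTF_APPLICATION_ID=application_1490000000000_0001
+#   export DTF_SERVICE_CLASS=yarn-dtf DTF_ZK_HOSTS=localhost:2181
+#   export DTF_TASK_PROGRAM=task.sh DTF_TASK_JOB_NAME=worker DTF_TASK_INDEX=0
+#   export CONTAINER_ID=container_01
+#   python -m wrapper --debug
+#
+# src/test/resources/test.sh exercises a similar flow end to end.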
+def get_envs():
+    """
+    Retrieve the input data from env vars
+    """
+    opts = {}
+    keys = [
+        DTF_APPLICATION_ID, DTF_ZK_HOSTS, DTF_SERVICE_CLASS,
+        DTF_TASK_PROGRAM, DTF_TASK_JOB_NAME, DTF_TASK_INDEX,
+        HADOOP_HOME, JAVA_HOME, USER, CONTAINER_ID]
+    for key in keys:
+        env = os.getenv(key)
+        if env is None:
+            LOG.error("%s env var not found", key)
+            raise KeyError(key)
+        opts[key] = env
+    return opts
+
+def zookeeper_get_spec(opts, port_data):
+    """
+    Use ZooKeeper to register this container's port and fetch the cluster spec
+    """
+    service_class = opts[DTF_SERVICE_CLASS]
+    app_id = opts[DTF_APPLICATION_ID]
+    container_id = opts[CONTAINER_ID]
+    user = opts[USER]
+    zk_host = opts[DTF_ZK_HOSTS]
+    zk_client = KazooClient(hosts=zk_host)
+    zk_client.start()
+
+    am_path = AM_PATH.format(service_class=service_class, app_id=app_id, user=user)
+    cn_path = am_path + CN_PATH.format(container_id=container_id)
+    zk_client.ensure_path(cn_path)
+
+    port_data = port_data.encode('ascii')
+
+    zk_client.set(cn_path, port_data)
+
+    LOG.info("Waiting for cluster spec")
+
+    event = Event()
+    data, _ = zk_client.get(am_path, watch=lambda _: event.set())
+    if data == b'':
+        event.wait()
+
+    data, _ = zk_client.get(am_path)
+
+    data = data.decode('ascii')
+    spec = json.loads(data)
+    return spec
+
+def main():
+    """
+    Main entry point
+    """
+
+    if sys.argv[-1] == "--debug":
+        logging.basicConfig(format=LOG_FORMAT, level=logging.DEBUG)
+    else:
+        logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)
+
+    opts = get_envs()
+    LOG.debug(os.environ)
+
+    port, sock = get_socket()
+
+    LOG.debug("Reserved port number %d", port)
+
+    port_data = json.dumps({
+        "type": "JSONServiceRecord",
+        "description": "YARN Distributed TensorFlow Container",
+        "external": [],
+        "internal": [],
+        "yarn:persistence": "container",
+        "yarn:id": opts[CONTAINER_ID],
+        "task_job_name": opts[DTF_TASK_JOB_NAME],
+        "task_job_index": opts[DTF_TASK_INDEX],
+        "task_port": str(port),
+    })
+
+    cluster_spec = zookeeper_get_spec(opts, port_data)
+
+    LOG.debug("Spec %s", cluster_spec)
+    for key in cluster_spec.keys():
+        if key.startswith('DTF_'):
+            os.environ[key] = cluster_spec[key]
+    sock.close()
+    return launch_prog(opts)
+
+def exit_handler(signum, _):
+    """
+    Handle a termination signal
+    """
+    LOG.info('Killed by signal %d', signum)
+    sys.exit(0)
+
+def kill_docker():
+    "Clean up the docker container, if one was started"
+    docker_image = os.getenv(DTF_DOCKER_IMAGE)
+    container_id = os.getenv(CONTAINER_ID)
+    if container_id and docker_image:
+        LOG.info('Killing docker container %s (image %s)', container_id, docker_image)
+        kill_cmd = "docker ps | awk '/{container_id}/ {{print $1}}' | xargs -r docker kill"
+        proc = subprocess.Popen(
+            ["bash", "-c", kill_cmd.format(container_id=container_id)])
+        proc.communicate()
+
+if __name__ == '__main__':
+    atexit.register(kill_docker)
+    signal.signal(signal.SIGTERM, exit_handler)
+    sys.exit(main())
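+
+# For reference, the exchange implemented above (values are illustrative):
+# each container writes the JSONServiceRecord built in main() to its znode
+# at AM_PATH + "/components/<container_id>", then blocks until the
+# ApplicationMaster publishes the assembled cluster spec at AM_PATH,
+# e.g. something of the form
+#
+#   {"DTF_WORKER_HOSTS": "host1:2222,host2:2222", "DTF_PS_HOSTS": "host0:2222"}
+#
+# after which every DTF_* key in the spec is exported to the task program's
+# environment (see env_check.sh in the tests).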
diff --git a/yarn/src/test/java/org/tensorflow/hadoop/yarn/EmbeddedZKServer.java b/yarn/src/test/java/org/tensorflow/hadoop/yarn/EmbeddedZKServer.java
new file mode 100644
index 00000000..069c743c
--- /dev/null
+++ b/yarn/src/test/java/org/tensorflow/hadoop/yarn/EmbeddedZKServer.java
@@ -0,0 +1,67 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tensorflow.hadoop.yarn;
+
+/**
+ * Embedded ZooKeeper server used by the TensorFlow-on-YARN tests
+ */
+
+import java.io.File;
+import java.io.IOException;
+import java.net.BindException;
+import java.net.InetSocketAddress;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.zookeeper.server.NIOServerCnxnFactory;
+import org.apache.zookeeper.server.ZooKeeperServer;
+
+class EmbeddedZKServer {
+
+  private static final Log LOG = LogFactory.getLog(EmbeddedZKServer.class);
+  private NIOServerCnxnFactory zkFactory;
+  private int zkport;
+
+  void start() throws IOException, InterruptedException {
+    LOG.info("Starting up embedded Zookeeper server");
+    File localfile = new File("./target/zookeeper.data");
+    ZooKeeperServer zkServer = new ZooKeeperServer(localfile, localfile, 2000);
+    // Scan upwards from port 60000 until a free port is found
+    for (zkport = 60000; true; zkport++) {
+      try {
+        zkFactory = new NIOServerCnxnFactory();
+        zkFactory.configure(new InetSocketAddress(zkport), 2000);
+        break;
+      } catch (BindException e) {
+        if (zkport == 65535) {
+          throw new IOException("Failed to find a port for the Zookeeper server to bind to");
+        }
+      }
+    }
+    LOG.info("Zookeeper port allocated: " + zkport);
+    zkFactory.startup(zkServer);
+  }
+
+  int port() {
+    return zkport;
+  }
+
+  void stop() {
+    LOG.info("Shutting down embedded Zookeeper server on port " + zkport);
+    zkFactory.shutdown();
+    zkFactory = null;
+  }
+}
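+
+// Usage sketch (mirrors what TestTF.setup() does below; `conf` stands for
+// the test's YarnConfiguration):
+//
+//   EmbeddedZKServer zkServer = new EmbeddedZKServer();
+//   zkServer.start();  // binds the first free port at or above 60000
+//   conf.set(RegistryConstants.KEY_REGISTRY_ZK_QUORUM,
+//       "localhost:" + zkServer.port());
+//   ...
+//   zkServer.stop();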
diff --git a/yarn/src/test/java/org/tensorflow/hadoop/yarn/TestTF.java b/yarn/src/test/java/org/tensorflow/hadoop/yarn/TestTF.java
new file mode 100644
index 00000000..7b9d5f40
--- /dev/null
+++ b/yarn/src/test/java/org/tensorflow/hadoop/yarn/TestTF.java
@@ -0,0 +1,408 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.tensorflow.hadoop.yarn;
+
+/**
+ * Integration tests for the TensorFlow launcher for YARN
+ */
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.registry.client.api.RegistryConstants;
+import org.apache.hadoop.util.JarFinder;
+import org.apache.hadoop.yarn.api.records.ApplicationReport;
+import org.apache.hadoop.yarn.api.records.FinalApplicationStatus;
+import org.apache.hadoop.yarn.api.records.YarnApplicationState;
+import org.apache.hadoop.yarn.api.records.timeline.TimelineEntities;
+import org.apache.hadoop.yarn.api.records.timeline.TimelineEntity;
+import org.apache.hadoop.yarn.api.records.timeline.TimelineEvent;
+import org.apache.hadoop.yarn.client.api.YarnClient;
+import org.apache.hadoop.yarn.conf.YarnConfiguration;
+import org.apache.hadoop.yarn.server.MiniYARNCluster;
+import org.apache.hadoop.yarn.server.resourcemanager.scheduler.capacity.CapacityScheduler;
+import org.junit.*;
+
+import java.io.*;
+import java.net.URL;
+import java.util.List;
+
+import static java.lang.Thread.sleep;
+
+public class TestTF {
+  private static final Log LOG = LogFactory.getLog(TestTF.class);
+
+  private static MiniYARNCluster yarnCluster = null;
+  private static YarnConfiguration conf = null;
+
+  // Internal ZooKeeper instance for the integration test run
+  private static EmbeddedZKServer zkServer;
+
+  private final static String APPMASTER_JAR = JarFinder.getJar(ApplicationMaster.class);
+  private String envCheckShell;
+
+  private static synchronized File createYarnSiteConfig(Configuration yarnConf) throws IOException {
+    URL url = Thread.currentThread().getContextClassLoader().getResource("yarn-site.xml");
+    if (url == null) {
+      throw new RuntimeException("Could not find 'yarn-site.xml' dummy file in classpath");
+    }
+    File yarnSiteXml = new File(url.getPath());
+    FileWriter writer = new FileWriter(yarnSiteXml);
+    yarnConf.set("yarn.application.classpath", yarnSiteXml.getParent());
+    yarnConf.setInt("yarn.nodemanager.delete.debug-delay-sec", 600);
+    yarnConf.writeXml(writer);
+    writer.flush();
+    writer.close();
+    return yarnSiteXml;
+  }
+
+  @Before
+  public void setup() throws Exception {
+    LOG.info("Starting up YARN cluster");
+
+    if (zkServer == null) {
+      zkServer = new EmbeddedZKServer();
+      zkServer.start();
+    }
+
+    conf = new YarnConfiguration();
+    conf.setInt(YarnConfiguration.RM_SCHEDULER_MINIMUM_ALLOCATION_MB, 1024);
+    conf.set("yarn.log.dir", "target");
+    conf.setBoolean(YarnConfiguration.TIMELINE_SERVICE_ENABLED, true);
+    conf.set(RegistryConstants.KEY_REGISTRY_ZK_QUORUM, "localhost:" + zkServer.port());
+    conf.set(YarnConfiguration.RM_SCHEDULER, CapacityScheduler.class.getName());
+
+    if (yarnCluster == null) {
+      yarnCluster = new MiniYARNCluster(
+          TestTF.class.getSimpleName(), 1, 1, 1, 1);
+      yarnCluster.init(conf);
+
+      yarnCluster.start();
+      conf.set(YarnConfiguration.TIMELINE_SERVICE_WEBAPP_ADDRESS,
+          MiniYARNCluster.getHostname() + ":"
+              + yarnCluster.getApplicationHistoryServer().getPort());
+    }
+    sleep(2000);
+
+    Configuration miniYarnConf = yarnCluster.getConfig();
+    createYarnSiteConfig(miniYarnConf);
+
+    URL url = Thread.currentThread().getContextClassLoader().getResource("env_check.sh");
+    if (url == null) {
+      throw new RuntimeException("Could not find 'env_check.sh' file in resources");
+    }
+    envCheckShell = url.getPath();
+  }
+
+  @After
+  public void tearDown() throws IOException {
+    if (yarnCluster != null) {
+      LOG.info("Shutting down MiniYARN cluster");
+      try {
+        yarnCluster.stop();
+      } finally {
+        yarnCluster = null;
+      }
+    }
+
+    // Shut down the ZooKeeper server
+    if (zkServer != null) {
+      LOG.info("Shutting down zookeeper");
+      zkServer.stop();
+      zkServer = null;
+    }
+  }
+
+  /*
+   * Launches the client, waits for the app to finish, and returns the
+   * final application status
+   */
+  private FinalApplicationStatus getApplicationReport(String[] args) throws Exception {
+    final Client client = new Client(new Configuration(yarnCluster.getConfig()));
+    LOG.info("Initializing YARN TensorFlow Client");
+    boolean initSuccess = client.init(args);
+    Assert.assertTrue(initSuccess);
+    LOG.info("Running YARN TensorFlow Client");
+    boolean result = client.run();
+    Assert.assertTrue(result);
+    LOG.info("Client run completed.");
+    YarnClient yarnClient = YarnClient.createYarnClient();
+    yarnClient.init(new Configuration(yarnCluster.getConfig()));
+    yarnClient.start();
+    ApplicationReport appReport = null;
+    boolean finished = false;
+    YarnApplicationState state = YarnApplicationState.NEW;
+    while (!finished) {
+      List<ApplicationReport> apps = yarnClient.getApplications();
+      if (apps.size() == 0) {
+        sleep(10);
+        continue;
+      }
+      appReport = apps.get(0);
+      if (appReport.getHost().equals("N/A")) {
+        sleep(10);
+        continue;
+      }
+      state = appReport.getYarnApplicationState();
+      if (state == YarnApplicationState.FINISHED
+          || state == YarnApplicationState.FAILED
+          || state == YarnApplicationState.KILLED) {
+        finished = true;
+      }
+    }
+    Assert.assertNotNull(appReport);
+    Assert.assertEquals(YarnApplicationState.FINISHED, state);
+    return appReport.getFinalApplicationStatus();
+  }
+
+  /*
+   * Launches 3 containers, 1 ps and 2 workers; all succeed
+   */
+  @Test(timeout = 90000)
+  public void testPositive() throws Exception {
+    String[] args = {
+        "--jar",
+        APPMASTER_JAR,
+        "-container_vcores",
+        "1",
+        "-container_memory",
+        "1024",
+        "-num_containers",
+        "ps:1,worker:2",
+        "-input_path", ".",
+        "-output_path", ".",
+        "-task_script", envCheckShell,
+        "-task_cmd", "sh ${DTF_TASK_SCRIPT}",
+        "-task_args", "test"
+    };
+
+    FinalApplicationStatus status = getApplicationReport(args);
+    Assert.assertEquals(FinalApplicationStatus.SUCCEEDED, status);
+
+    TimelineEntities entitiesAttempts = yarnCluster
+        .getApplicationHistoryServer()
+        .getTimelineStore()
+        .getEntities(ApplicationMaster.TFEntity.TF_APP_ATTEMPT.toString(),
+            null, null, null, null, null, null, null, null, null);
+    Assert.assertNotNull(entitiesAttempts);
+    Assert.assertEquals(1, entitiesAttempts.getEntities().size());
+    Assert.assertEquals(2, entitiesAttempts.getEntities().get(0).getEvents().size());
+    Assert.assertEquals(ApplicationMaster.TFEntity.TF_APP_ATTEMPT.toString(),
+        entitiesAttempts.getEntities().get(0).getEntityType());
+    TimelineEntities entities = yarnCluster
+        .getApplicationHistoryServer()
+        .getTimelineStore()
+        .getEntities(ApplicationMaster.TFEntity.TF_CONTAINER.toString(), null,
+            null, null, null, null, null, null, null, null);
+    Assert.assertNotNull(entities);
+    // "ps:1,worker:2" = 3 containers
+    Assert.assertEquals(3, entities.getEntities().size());
+    Assert.assertEquals(ApplicationMaster.TFEntity.TF_CONTAINER.toString(),
+        entities.getEntities().get(0).getEntityType());
+    for (TimelineEntity entity : entities.getEntities()) {
+      // Each container has TF_CONTAINER_START and TF_CONTAINER_END events
+      for (TimelineEvent event : entity.getEvents()) {
+        if (event.getEventType().equals(ApplicationMaster.TFEvent.TF_CONTAINER_END.toString())) {
+          int exitStatus = (Integer) event.getEventInfo().get(ApplicationMaster.TFInfo.TF_EXIT_STATUS.toString());
+          Assert.assertEquals(0, exitStatus);
+        }
+      }
+    }
+  }
+
+  /*
+   * Launches 3 containers, 1 ps and 2 workers; all fail
+   */
+  @Test(timeout = 90000)
+  public void testFailAll() throws Exception {
+    String[] args = {
+        "--jar",
+        APPMASTER_JAR,
+        "-container_vcores",
+        "1",
+        "-container_memory",
+        "1024",
+        "-num_containers",
+        "ps:1,worker:2",
+        "-input_path", ".",
+        "-output_path", ".",
+        "-task_script", envCheckShell,
+        "-task_cmd", "sh ${DTF_TASK_SCRIPT} && exit 1",
+        "-task_args", ""
+    };
+
+    FinalApplicationStatus status = getApplicationReport(args);
+
+    Assert.assertEquals(FinalApplicationStatus.FAILED, status);
+
+    TimelineEntities entitiesAttempts = yarnCluster
+        .getApplicationHistoryServer()
+        .getTimelineStore()
+        .getEntities(ApplicationMaster.TFEntity.TF_APP_ATTEMPT.toString(),
+            null, null, null, null, null, null, null, null, null);
+    Assert.assertNotNull(entitiesAttempts);
+    Assert.assertEquals(1, entitiesAttempts.getEntities().size());
+    Assert.assertEquals(2, entitiesAttempts.getEntities().get(0).getEvents().size());
+    Assert.assertEquals(ApplicationMaster.TFEntity.TF_APP_ATTEMPT.toString(),
+        entitiesAttempts.getEntities().get(0).getEntityType());
+
+    TimelineEntities entities = yarnCluster
+        .getApplicationHistoryServer()
+        .getTimelineStore()
+        .getEntities(ApplicationMaster.TFEntity.TF_CONTAINER.toString(), null,
+            null, null, null, null, null, null, null, null);
+    Assert.assertNotNull(entities);
+    // "ps:1,worker:2" = 3 containers
+    Assert.assertEquals(3, entities.getEntities().size());
+    Assert.assertEquals(ApplicationMaster.TFEntity.TF_CONTAINER.toString(),
+        entities.getEntities().get(0).getEntityType());
+    for (TimelineEntity entity : entities.getEntities()) {
+      // Each container has TF_CONTAINER_START and TF_CONTAINER_END events
+      for (TimelineEvent event : entity.getEvents()) {
+        if (event.getEventType().equals(ApplicationMaster.TFEvent.TF_CONTAINER_END.toString())) {
+          int exitStatus = (Integer) event.getEventInfo().get(ApplicationMaster.TFInfo.TF_EXIT_STATUS.toString());
+          Assert.assertEquals(1, exitStatus);
+        }
+      }
+    }
+  }
+
+  /*
+   * Launches 3 containers, 1 ps and 2 workers; fails the chief worker (task index 0)
+   */
+  @Test(timeout = 90000)
+  public void testFailChiefWorker() throws Exception {
+    String[] args = {
+        "--jar",
+        APPMASTER_JAR,
+        "-container_vcores",
+        "1",
+        "-container_memory",
+        "1024",
+        "-num_containers",
+        "ps:1,worker:2",
+        "-input_path", ".",
+        "-output_path", ".",
+        "-task_script", envCheckShell,
+        "-task_cmd", "sh ${DTF_TASK_SCRIPT} && if [ \"${DTF_TASK_JOB_NAME}\" == 'worker' -a \"${DTF_TASK_INDEX}\" == '0' ]; then exit 1; fi",
+        "-task_args", ""
+    };
+
+    FinalApplicationStatus status = getApplicationReport(args);
+
+    Assert.assertEquals(FinalApplicationStatus.FAILED, status);
+
+    TimelineEntities entitiesAttempts = yarnCluster
+        .getApplicationHistoryServer()
+        .getTimelineStore()
+        .getEntities(ApplicationMaster.TFEntity.TF_APP_ATTEMPT.toString(),
+            null, null, null, null, null, null, null, null, null);
+    Assert.assertNotNull(entitiesAttempts);
+    Assert.assertEquals(1, entitiesAttempts.getEntities().size());
+    Assert.assertEquals(2, entitiesAttempts.getEntities().get(0).getEvents().size());
+    Assert.assertEquals(ApplicationMaster.TFEntity.TF_APP_ATTEMPT.toString(),
+        entitiesAttempts.getEntities().get(0).getEntityType());
+
+    TimelineEntities entities = yarnCluster
+        .getApplicationHistoryServer()
+        .getTimelineStore()
+        .getEntities(ApplicationMaster.TFEntity.TF_CONTAINER.toString(), null,
+            null, null, null, null, null, null, null, null);
+    Assert.assertNotNull(entities);
+    // "ps:1,worker:2" = 3 containers
+    Assert.assertEquals(3, entities.getEntities().size());
+    Assert.assertEquals(ApplicationMaster.TFEntity.TF_CONTAINER.toString(),
+        entities.getEntities().get(0).getEntityType());
+    for (TimelineEntity entity : entities.getEntities()) {
+      // Each container has TF_CONTAINER_START and TF_CONTAINER_END events
+      for (TimelineEvent event : entity.getEvents()) {
+        if (event.getEventType().equals(ApplicationMaster.TFEvent.TF_CONTAINER_END.toString())) {
+          int exitStatus = (Integer) event.getEventInfo().get(ApplicationMaster.TFInfo.TF_EXIT_STATUS.toString());
+          String taskName = (String) event.getEventInfo().get(ApplicationMaster.TFInfo.TF_TASK_NAME.toString());
+          if (taskName.equals("worker[0]")) {
+            Assert.assertEquals(1, exitStatus);
+          } else {
+            Assert.assertEquals(0, exitStatus);
+          }
+        }
+      }
+    }
+  }
+
+  /*
+   * Launches 3 containers, 1 ps and 2 workers; fails the non-chief worker (task index 1)
+   */
+  @Test(timeout = 90000)
+  public void testFailNonChiefWorker() throws Exception {
+    String[] args = {
+        "--jar",
+        APPMASTER_JAR,
+        "-container_vcores",
+        "1",
+        "-container_memory",
+        "1024",
+        "-num_containers",
+        "ps:1,worker:2",
+        "-input_path", ".",
+        "-output_path", ".",
+        "-task_script", envCheckShell,
+        "-task_cmd", "sh ${DTF_TASK_SCRIPT} && if [ \"${DTF_TASK_JOB_NAME}\" == 'worker' -a \"${DTF_TASK_INDEX}\" == '1' ]; then exit 1; fi",
+        "-task_args", ""
+    };
+
+    FinalApplicationStatus status = getApplicationReport(args);
+
+    Assert.assertEquals(FinalApplicationStatus.FAILED, status);
+
+    TimelineEntities entitiesAttempts = yarnCluster
+        .getApplicationHistoryServer()
+        .getTimelineStore()
+        .getEntities(ApplicationMaster.TFEntity.TF_APP_ATTEMPT.toString(),
+            null, null, null, null, null, null, null, null, null);
+    Assert.assertNotNull(entitiesAttempts);
+    Assert.assertEquals(1, entitiesAttempts.getEntities().size());
+    Assert.assertEquals(2, entitiesAttempts.getEntities().get(0).getEvents().size());
+    Assert.assertEquals(ApplicationMaster.TFEntity.TF_APP_ATTEMPT.toString(),
+        entitiesAttempts.getEntities().get(0).getEntityType());
+
+    TimelineEntities entities = yarnCluster
+        .getApplicationHistoryServer()
+        .getTimelineStore()
+        .getEntities(ApplicationMaster.TFEntity.TF_CONTAINER.toString(), null,
+            null, null, null, null, null, null, null, null);
+    Assert.assertNotNull(entities);
+    // "ps:1,worker:2" = 3 containers
+    Assert.assertEquals(3, entities.getEntities().size());
+    Assert.assertEquals(ApplicationMaster.TFEntity.TF_CONTAINER.toString(),
+        entities.getEntities().get(0).getEntityType());
+    for (TimelineEntity entity : entities.getEntities()) {
+      // Each container has TF_CONTAINER_START and TF_CONTAINER_END events
+      for (TimelineEvent event : entity.getEvents()) {
+        if (event.getEventType().equals(ApplicationMaster.TFEvent.TF_CONTAINER_END.toString())) {
+          int exitStatus = (Integer) event.getEventInfo().get(ApplicationMaster.TFInfo.TF_EXIT_STATUS.toString());
+          String taskName = (String) event.getEventInfo().get(ApplicationMaster.TFInfo.TF_TASK_NAME.toString());
+          if (taskName.equals("worker[1]")) {
+            Assert.assertEquals(1, exitStatus);
+          } else {
+            Assert.assertEquals(0, exitStatus);
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/yarn/src/test/resources/env_check.sh b/yarn/src/test/resources/env_check.sh
new file mode 100644
index 00000000..5348cc6e
--- /dev/null
+++ b/yarn/src/test/resources/env_check.sh
@@ -0,0 +1,11 @@
+#!/bin/sh
+echo DTF_SERVICE_CLASS: "${DTF_SERVICE_CLASS:?DTF_SERVICE_CLASS is not set}"
"${DTF_TASK_SCRIPT:?DTF_TASK_SCRIPT is not set}" +echo DTF_APPLICATION_ID: "${DTF_APPLICATION_ID:?DTF_APPLICATION_ID is not set}" +echo DTF_TASK_INDEX: "${DTF_TASK_INDEX:?DTF_TASK_INDEX is not set}" +echo DTF_ZK_HOSTS: "${DTF_ZK_HOSTS:?Service is not set}" +echo DTF_INPUT_PATH: "${DTF_INPUT_PATH:?Service is not set}" +echo DTF_OUTPUT_PATH: "${DTF_OUTPUT_PATH:?DTF_OUTPUT_PATH is not set}" +echo DTF_WORKER_HOSTS: "${DTF_WORKER_HOSTS:?DTF_WORKER_HOSTS is not set}" +echo DTF_PS_HOSTS: "${DTF_PS_HOSTS:?DTF_PS_HOSTS is not set}" +echo DTF_TASK_JOB_NAME: "${DTF_TASK_JOB_NAME:?DTF_TASK_JOB_NAME is not set}" diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties new file mode 100644 index 00000000..531b68b5 --- /dev/null +++ b/yarn/src/test/resources/log4j.properties @@ -0,0 +1,19 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# log4j configuration used during build and unit tests + +log4j.rootLogger=info,stdout +log4j.threshhold=ALL +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{ISO8601} %-5p [%t] %c{2} (%F:%M(%L)) - %m%n diff --git a/yarn/src/test/resources/test.sh b/yarn/src/test/resources/test.sh new file mode 100644 index 00000000..3a0436c2 --- /dev/null +++ b/yarn/src/test/resources/test.sh @@ -0,0 +1,38 @@ +#!/usr/bin/bash +set -u -e +export USER=test_user +export HADOOP_HOME=~/hadoop +export CONTAINER_ID=234 +export DTF_ZK_HOSTS=localhost:2181 +export DTF_SERVICE_CLASS=yarn-dtf +export DTF_APPLICATION_ID=123 +export DTF_TASK_PROGRAM="test_task.sh" +export DTF_TASK_JOB_NAME=ps +export DTF_TASK_INDEX=0 +export DTF_INPUT_PATH=. +export DTF_OUTPUT_PATH=. + +( +python - < + + + + + + + + diff --git a/yarn/task-starter-assembly.xml b/yarn/task-starter-assembly.xml new file mode 100644 index 00000000..832ec7b0 --- /dev/null +++ b/yarn/task-starter-assembly.xml @@ -0,0 +1,31 @@ + + + task-starter + + zip + + false + + + src/main/python/task-starter + + + **/*.pyc + + + +