diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index f483a61e..9d6f5dde 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -983,6 +983,10 @@ def get_container_flags( group_env_vars.append(current_env_vars) + _container_env = set(resource_req.container_env or []) + _container_env.update(full_env_vars.keys()) + _container_env.update(resource_req.env_vars.keys()) + _container_flags = get_container_flags( base_mounts=resource_req.container_mounts, src_job_dir=os.path.join( @@ -990,7 +994,7 @@ def get_container_flags( job_directory_name, ), container_image=resource_req.container_image, - container_env=resource_req.container_env, + container_env=sorted(_container_env), ) _srun_args = ["--wait=60", "--kill-on-bad-exit=1"] _srun_args.extend(resource_req.srun_args or []) @@ -999,6 +1003,10 @@ def get_container_flags( cmd_stderr = stderr_flags.copy() if cmd_stderr: cmd_stderr[-1] = cmd_stderr[-1].replace(original_job_name, self.jobs[group_ind]) + + _container_env = set(self.executor.container_env or []) + _container_env.update(full_env_vars.keys()) + _container_flags = get_container_flags( base_mounts=self.executor.container_mounts, src_job_dir=os.path.join( @@ -1006,7 +1014,7 @@ def get_container_flags( job_directory_name, ), container_image=self.executor.container_image, - container_env=self.executor.container_env, + container_env=sorted(_container_env), ) _srun_args = ["--wait=60", "--kill-on-bad-exit=1"] _srun_args.extend(self.executor.srun_args or [])