From 90eacc5b525195163527ce1f0211bcecbd4e642e Mon Sep 17 00:00:00 2001 From: Alex Filby Date: Fri, 21 Nov 2025 11:30:03 -0600 Subject: [PATCH] Set container_env for slurm sbatch when using env_vars Signed-off-by: Alex Filby --- nemo_run/core/execution/slurm.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/nemo_run/core/execution/slurm.py b/nemo_run/core/execution/slurm.py index f483a61e..9d6f5dde 100644 --- a/nemo_run/core/execution/slurm.py +++ b/nemo_run/core/execution/slurm.py @@ -983,6 +983,10 @@ def get_container_flags( group_env_vars.append(current_env_vars) + _container_env = set(resource_req.container_env or []) + _container_env.update(full_env_vars.keys()) + _container_env.update(resource_req.env_vars.keys()) + _container_flags = get_container_flags( base_mounts=resource_req.container_mounts, src_job_dir=os.path.join( @@ -990,7 +994,7 @@ def get_container_flags( job_directory_name, ), container_image=resource_req.container_image, - container_env=resource_req.container_env, + container_env=sorted(_container_env), ) _srun_args = ["--wait=60", "--kill-on-bad-exit=1"] _srun_args.extend(resource_req.srun_args or []) @@ -999,6 +1003,10 @@ def get_container_flags( cmd_stderr = stderr_flags.copy() if cmd_stderr: cmd_stderr[-1] = cmd_stderr[-1].replace(original_job_name, self.jobs[group_ind]) + + _container_env = set(self.executor.container_env or []) + _container_env.update(full_env_vars.keys()) + _container_flags = get_container_flags( base_mounts=self.executor.container_mounts, src_job_dir=os.path.join( @@ -1006,7 +1014,7 @@ def get_container_flags( job_directory_name, ), container_image=self.executor.container_image, - container_env=self.executor.container_env, + container_env=sorted(_container_env), ) _srun_args = ["--wait=60", "--kill-on-bad-exit=1"] _srun_args.extend(self.executor.srun_args or [])