diff --git a/cdmt.cu b/cdmt.cu index 3f43fef..dc8a248 100644 --- a/cdmt.cu +++ b/cdmt.cu @@ -66,6 +66,7 @@ int main(int argc,char *argv[]) unsigned char *cbuf,*dcbuf; float *fbuf,*dfbuf; float *bs1,*bs2,*zavg,*zstd; + void* cufftworkarea; cufftComplex *cp1,*cp2,*dc,*cp1p,*cp2p; cufftHandle ftc2cf,ftc2cb; int idist,odist,iembed,oembed,istride,ostride; @@ -188,13 +189,33 @@ int main(int argc,char *argv[]) checkCudaErrors(cudaMalloc((void **) &ddm,sizeof(float)*ndm)); checkCudaErrors(cudaMemcpy(ddm,dm,sizeof(float)*ndm,cudaMemcpyHostToDevice)); + + // Disable initial memory allocates and silence the compiler warnings; + // Nvidia uses a custom compiler frontend so GCC pragmas do not work. + // This order-of-execution follows Nvidia's usage guidance + #pragma diag_suppress used_before_set + cufftSetAutoAllocation(ftc2cf, 0); + cufftSetAutoAllocation(ftc2cb, 0); + #pragma diag_default used_before_set + size_t cfSize, cbSize; + // Generate FFT plan (batch in-place forward FFT) idist=nbin; odist=nbin; iembed=nbin; oembed=nbin; istride=1; ostride=1; checkCudaErrors(cufftPlanMany(&ftc2cf,1,&nbin,&iembed,istride,idist,&oembed,ostride,odist,CUFFT_C2C,nfft*nsub)); + checkCudaErrors(cufftGetSizeMany(ftc2cf,1,&nbin,&iembed,istride,idist,&oembed,ostride,odist,CUFFT_C2C,nfft*nsub, &cfSize)); // Generate FFT plan (batch in-place backward FFT) idist=mbin; odist=mbin; iembed=mbin; oembed=mbin; istride=1; ostride=1; checkCudaErrors(cufftPlanMany(&ftc2cb,1,&mbin,&iembed,istride,idist,&oembed,ostride,odist,CUFFT_C2C,nchan*nfft*nsub)); + checkCudaErrors(cufftGetSizeMany(ftc2cb,1,&mbin,&iembed,istride,idist,&oembed,ostride,odist,CUFFT_C2C,nchan*nfft*nsub,&cbSize)); + + + // Allocate the maxmimum memory needed + size_t minfftsize = cfSize > cbSize ? cfSize : cbSize; + checkCudaErrors(cudaMalloc((void**) &cufftworkarea, (size_t) minfftsize)); + // Set the cuFFT handles to use this area + cufftSetWorkArea(ftc2cf, cufftworkarea); + cufftSetWorkArea(ftc2cb, cufftworkarea); // Compute chirp blocksize.x=32; blocksize.y=32; blocksize.z=1; @@ -343,6 +364,7 @@ int main(int argc,char *argv[]) // Free plan cufftDestroy(ftc2cf); cufftDestroy(ftc2cb); + cudaFree(cufftworkarea); return 0; }