ebsynth代码源码分析

发布时间 2023-07-28 14:50:01作者: 白与花糖

最近接到了优化ebsynth性能的任务。对于一个新的算法,要优化它的性能,我觉得要从三步来分析

首当其冲的是要看懂代码,算法到底是怎么跑的,实现了什么功能,怎么实现的。

接着在不借助性能分析工具的情况下通读代码,把感觉像是性能瓶颈,算法内核的地方记录下来

再用性能分析优化工具分析出确切的位置,思考为什么这个地方会是瓶颈,怎么优化。

jamriska/ebsynth: Fast Example-based Image Synthesis and Style Transfer (github.com)

 

接下来进入正题。

 

ebsynth_cuda_memarray2.h文件

#ifndef EBSYNTH_CUDA_MEMARRAY2_H_
#define EBSYNTH_CUDA_MEMARRAY2_H_

#include "jzq.h"
#include "ebsynth_cuda_check.h"

template<typename T>
struct MemArray2
{
  T* data;
  int width;
  int height;

  MemArray2() : width(0),height(0),data(0) {};

  MemArray2(const V2i& size)
  {
    width = size(0);
    height = size(1);
    checkCudaError(cudaMalloc(&data,width*height*sizeof(T)));
  }

  MemArray2(int _width,int _height)
  {
    width = _width;
    height = _height;
    checkCudaError(cudaMalloc(&data,width*height*sizeof(T)));
  }
  /*
  int       __device__ operator()(int i,int j)
  {
    return data[i+j*width];
  }

  const int& __device__ operator()(int i,int j) const
  {
    return data[i+j*width];
  }
  */

  void destroy()
  {
    checkCudaError( cudaFree(data) );
  }
};

template<typename T>
void copy(MemArray2<T>* out_dst,const Array2<T>& src)
{
  assert(out_dst != 0);
  MemArray2<T>& dst = *out_dst;
  assert(dst.width == src.width());
  assert(dst.height == src.height());

  checkCudaError(cudaMemcpy(dst.data, src.data(), src.width()*src.height()*sizeof(T), cudaMemcpyHostToDevice));
}

template<typename T>
void copy(Array2<T>* out_dst,const MemArray2<T>& src)
{
  assert(out_dst != 0);
  const Array2<T>& dst = *out_dst;
  assert(dst.width() == src.width);
  assert(dst.height() == src.height);

  checkCudaError(cudaMemcpy((void*)dst.data(),src.data, src.width*src.height*sizeof(T), cudaMemcpyDeviceToHost));
}

#endif