Please note, this is a STATIC archive of website www.tutorialspoint.com from 11 May 2019, cach3.com does not collect or store any user information, there is no "phishing" involved.
Tutorialspoint

train_random

require 'torch'
require 'nn'
require 'optim'
require 'image'
util = paths.dofile('util.lua')

opt = {
   batchSize = 64,         -- number of samples to produce
   loadSize = 350,         -- resize the loaded image to loadsize maintaining aspect ratio. 0 means don't resize. -1 means scale randomly between [0.5,2] -- see donkey_folder.lua
   fineSize = 128,         -- size of random crops. Only 64 and 128 supported.
   nBottleneck = 100,      -- #  of dim for bottleneck of encoder
   nef = 64,               -- #  of encoder filters in first conv layer
   ngf = 64,               -- #  of gen filters in first conv layer
   ndf = 64,               -- #  of discrim filters in first conv layer
   nc = 3,                 -- # of channels in input
   wtl2 = 0,               -- 0 means don't use else use with this weight
   useOverlapPred = 0,        -- overlapping edges (1 means yes, 0 means no). 1 means put 10x more L2 weight on unmasked region.
   nThreads = 4,           -- #  of data loading threads to use
   niter = 25,             -- #  of iter at starting learning rate
   lr = 0.0002,            -- initial learning rate for adam
   beta1 = 0.5,            -- momentum term of adam
   ntrain = math.huge,     -- #  of examples per epoch. math.huge for full dataset
   display = 1,            -- display samples while training. 0 = false
   display_id = 10,        -- display window id.
   display_iter = 50,      -- # number of iterations after which display is updated
   gpu = 1,                -- gpu = 0 is CPU mode. gpu=X is GPU mode on GPU X
   name = 'train1',        -- name of the experiment you are running
   manualSeed = 0,         -- 0 means random seed

   -- Extra Options:
   conditionAdv = 0,       -- 0 means false else true
   noiseGen = 0,           -- 0 means false else true
   noisetype = 'normal',   -- uniform / normal
   nz = 100,               -- #  of dim for Z
}
for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end
print(opt)
if opt.display == 0 then opt.display = false end
if opt.conditionAdv == 0 then opt.conditionAdv = false end
if opt.noiseGen == 0 then opt.noiseGen = false end

-- set seed
if opt.manualSeed == 0 then
    opt.manualSeed = torch.random(1, 10000)
end
print("Seed: " .. opt.manualSeed)
torch.manualSeed(opt.manualSeed)
torch.setnumthreads(1)
torch.setdefaulttensortype('torch.FloatTensor')

-- create data loader
local DataLoader = paths.dofile('data/data.lua')
local data = DataLoader.new(opt.nThreads, opt)
print("Dataset Size: ", data:size())

---------------------------------------------------------------------------
-- Initialize network variables
---------------------------------------------------------------------------
local function weights_init(m)
   local name = torch.type(m)
   if name:find('Convolution') then
      m.weight:normal(0.0, 0.02)
      m.bias:fill(0)
   elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
   end
end

local nc = opt.nc
local nz = opt.nz
local nBottleneck = opt.nBottleneck
local ndf = opt.ndf
local ngf = opt.ngf
local nef = opt.nef
local real_label = 1
local fake_label = 0

local SpatialBatchNormalization = nn.SpatialBatchNormalization
local SpatialConvolution = nn.SpatialConvolution
local SpatialFullConvolution = nn.SpatialFullConvolution

---------------------------------------------------------------------------
-- Generator net
---------------------------------------------------------------------------
-- Encode Input Context to noise (architecture similar to Discriminator)
local netE = nn.Sequential()
-- input is (nc) x 128 x 128
netE:add(SpatialConvolution(nc, nef, 4, 4, 2, 2, 1, 1))
netE:add(nn.LeakyReLU(0.2, true))
if opt.fineSize == 128 then
  -- state size: (nef) x 64 x 64
  netE:add(SpatialConvolution(nef, nef, 4, 4, 2, 2, 1, 1))
  netE:add(SpatialBatchNormalization(nef)):add(nn.LeakyReLU(0.2, true))
end
-- state size: (nef) x 32 x 32
netE:add(SpatialConvolution(nef, nef * 2, 4, 4, 2, 2, 1, 1))
netE:add(SpatialBatchNormalization(nef * 2)):add(nn.LeakyReLU(0.2, true))
-- state size: (nef*2) x 16 x 16
netE:add(SpatialConvolution(nef * 2, nef * 4, 4, 4, 2, 2, 1, 1))
netE:add(SpatialBatchNormalization(nef * 4)):add(nn.LeakyReLU(0.2, true))
-- state size: (nef*4) x 8 x 8
netE:add(SpatialConvolution(nef * 4, nef * 8, 4, 4, 2, 2, 1, 1))
netE:add(SpatialBatchNormalization(nef * 8)):add(nn.LeakyReLU(0.2, true))
-- state size: (nef*8) x 4 x 4
netE:add(SpatialConvolution(nef * 8, nBottleneck, 4, 4))
-- state size: (nBottleneck) x 1 x 1

local netG = nn.Sequential()
local nz_size = nBottleneck
if opt.noiseGen then
    local netG_noise = nn.Sequential()
    -- input is Z: (nz) x 1 x 1, going into a convolution
    netG_noise:add(SpatialConvolution(nz, nz, 1, 1, 1, 1, 0, 0))
    -- state size: (nz) x 1 x 1

    local netG_pl = nn.ParallelTable();
    netG_pl:add(netE)
    netG_pl:add(netG_noise)

    netG:add(netG_pl)
    netG:add(nn.JoinTable(2))
    netG:add(SpatialBatchNormalization(nBottleneck+nz)):add(nn.LeakyReLU(0.2, true))
    -- state size: (nBottleneck+nz) x 1 x 1

    nz_size = nBottleneck+nz
else
    netG:add(netE)
    netG:add(SpatialBatchNormalization(nBottleneck)):add(nn.LeakyReLU(0.2, true))

    nz_size = nBottleneck
end

-- Decode noise to generate image
-- input is Z: (nz_size) x 1 x 1, going into a convolution
netG:add(SpatialFullConvolution(nz_size, ngf * 8, 4, 4))
netG:add(SpatialBatchNormalization(ngf * 8)):add(nn.ReLU(true))
-- state size: (ngf*8) x 4 x 4
netG:add(SpatialFullConvolution(ngf * 8, ngf * 4, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf * 4)):add(nn.ReLU(true))
-- state size: (ngf*4) x 8 x 8
netG:add(SpatialFullConvolution(ngf * 4, ngf * 2, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf * 2)):add(nn.ReLU(true))
-- state size: (ngf*2) x 16 x 16
netG:add(SpatialFullConvolution(ngf * 2, ngf, 4, 4, 2, 2, 1, 1))
netG:add(SpatialBatchNormalization(ngf)):add(nn.ReLU(true))
-- state size: (ngf) x 32 x 32
if opt.fineSize == 128 then
  netG:add(SpatialFullConvolution(ngf, ngf, 4, 4, 2, 2, 1, 1))
  netG:add(SpatialBatchNormalization(ngf)):add(nn.ReLU(true))
  -- state size: (ngf) x 64 x 64
end
netG:add(SpatialFullConvolution(ngf, nc, 4, 4, 2, 2, 1, 1))
netG:add(nn.Tanh())
-- state size: (nc) x 128 x 128

netG:apply(weights_init)

---------------------------------------------------------------------------
-- Adversarial discriminator net
---------------------------------------------------------------------------
local netD = nn.Sequential()
if opt.conditionAdv then
    print('conditional adv not implemented')
    exit()
    local netD_ctx = nn.Sequential()
    -- input Context: (nc) x 128 x 128, going into a convolution
    netD_ctx:add(SpatialConvolution(nc, ndf, 5, 5, 2, 2, 2, 2))
    -- state size: (ndf) x 64 x 64

    local netD_pred = nn.Sequential()
    -- input pred: (nc) x 64 x 64, going into a convolution
    netD_pred:add(SpatialConvolution(nc, ndf, 5, 5, 2, 2, 2+32, 2+32))      -- 32: to keep scaling of features same as context
    -- state size: (ndf) x 64 x 64

    local netD_pl = nn.ParallelTable();
    netD_pl:add(netD_ctx)
    netD_pl:add(netD_pred)

    netD:add(netD_pl)
    netD:add(nn.JoinTable(2))
    netD:add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf * 2) x 64 x 64

    netD:add(SpatialConvolution(ndf*2, ndf, 4, 4, 2, 2, 1, 1))
    netD:add(SpatialBatchNormalization(ndf)):add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf) x 32 x 32
else
    -- input is (nc) x 128 x 128, going into a convolution
    netD:add(SpatialConvolution(nc, ndf, 4, 4, 2, 2, 1, 1))
    netD:add(nn.LeakyReLU(0.2, true))
    -- state size: (ndf) x 64 x 64
end
if opt.fineSize == 128 then
  netD:add(SpatialConvolution(ndf, ndf, 4, 4, 2, 2, 1, 1))
  netD:add(SpatialBatchNormalization(ndf)):add(nn.LeakyReLU(0.2, true))
  -- state size: (ndf) x 32 x 32
end
netD:add(SpatialConvolution(ndf, ndf * 2, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 2)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*2) x 16 x 16
netD:add(SpatialConvolution(ndf * 2, ndf * 4, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 4)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*4) x 8 x 8
netD:add(SpatialConvolution(ndf * 4, ndf * 8, 4, 4, 2, 2, 1, 1))
netD:add(SpatialBatchNormalization(ndf * 8)):add(nn.LeakyReLU(0.2, true))
-- state size: (ndf*8) x 4 x 4
netD:add(SpatialConvolution(ndf * 8, 1, 4, 4))
netD:add(nn.Sigmoid())
-- state size: 1 x 1 x 1
netD:add(nn.View(1):setNumInputDims(3))
-- state size: 1

netD:apply(weights_init)

---------------------------------------------------------------------------
-- Loss Metrics
---------------------------------------------------------------------------
local criterion = nn.BCECriterion()
local criterionMSE
if opt.wtl2~=0 then
  criterionMSE = nn.MSECriterion()
end

---------------------------------------------------------------------------
-- Setup Solver
---------------------------------------------------------------------------
print('LR of Gen is ',(opt.wtl2>0 and opt.wtl2<1) and 10 or 1,'times Adv')
optimStateG = {
   learningRate = (opt.wtl2>0 and opt.wtl2<1) and opt.lr*10 or opt.lr,
   beta1 = opt.beta1,
}
optimStateD = {
   learningRate = opt.lr,
   beta1 = opt.beta1,
}

---------------------------------------------------------------------------
-- Initialize data variables
---------------------------------------------------------------------------
local mask_global = torch.ByteTensor(opt.batchSize, opt.fineSize, opt.fineSize)
local input_ctx_vis = torch.Tensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
local input_ctx = torch.Tensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
local input_center = torch.Tensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
local input_real_center
if opt.wtl2~=0 then
    input_real_center = torch.Tensor(opt.batchSize, nc, opt.fineSize, opt.fineSize)
end
local noise = torch.Tensor(opt.batchSize, nz, 1, 1)
local label = torch.Tensor(opt.batchSize)
local errD, errG, errG_l2
local epoch_tm = torch.Timer()
local tm = torch.Timer()
local data_tm = torch.Timer()

if pcall(require, 'cudnn') and pcall(require, 'cunn') and opt.gpu>0 then
    print('Using CUDNN !')
end
if opt.gpu > 0 then
   require 'cunn'
   cutorch.setDevice(opt.gpu)
   input_ctx_vis = input_ctx_vis:cuda(); input_ctx = input_ctx:cuda();  input_center = input_center:cuda()
   noise = noise:cuda();  label = label:cuda()
   netG = util.cudnn(netG);     netD = util.cudnn(netD)
   netD:cuda();           netG:cuda();           criterion:cuda();      
   if opt.wtl2~=0 then
      criterionMSE:cuda(); input_real_center = input_real_center:cuda();
   end
end
print('NetG:',netG)
print('NetD:',netD)

-- Generating random pattern
local res = 0.06 -- the lower it is, the more continuous the output will be. 0.01 is too small and 0.1 is too large
local density = 0.25
local MAX_SIZE = 10000
local low_pattern = torch.Tensor(res*MAX_SIZE, res*MAX_SIZE):uniform(0,1):mul(255)
local pattern = image.scale(low_pattern, MAX_SIZE, MAX_SIZE,'bicubic')
low_pattern = nil
pattern:div(255);
pattern = torch.lt(pattern,density):byte()  -- 25% 1s and 75% 0s
pattern = pattern:byte()
print('...Random pattern generated')

local parametersD, gradParametersD = netD:getParameters()
local parametersG, gradParametersG = netG:getParameters()

if opt.display then disp = require 'display' end

noise_vis = noise:clone()
if opt.noisetype == 'uniform' then
    noise_vis:uniform(-1, 1)
elseif opt.noisetype == 'normal' then
    noise_vis:normal(0, 1)
end

---------------------------------------------------------------------------
-- Define generator and adversary closures
---------------------------------------------------------------------------
-- create closure to evaluate f(X) and df/dX of discriminator
local fDx = function(x)
   netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
   netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)

   gradParametersD:zero()

   -- train with real
   data_tm:reset(); data_tm:resume()
   local real_ctx = data:getBatch()
   real_center = real_ctx  -- view
   input_center:copy(real_center)
   if opt.wtl2~=0 then
      input_real_center:copy(real_center)
   end

   -- get random mask
   local mask, wastedIter
   wastedIter = 0
   while true do
     local x = torch.uniform(1, MAX_SIZE-opt.fineSize)
     local y = torch.uniform(1, MAX_SIZE-opt.fineSize)
     mask = pattern[{{y,y+opt.fineSize-1},{x,x+opt.fineSize-1}}]  -- view, no allocation
     local area = mask:sum()*100./(opt.fineSize*opt.fineSize)
     if area>20 and area<30 then  -- want it to be approx 75% 0s and 25% 1s
        -- print('wasted tries: ',wastedIter)
        break
     end
     wastedIter = wastedIter + 1
   end
   torch.repeatTensor(mask_global,mask,opt.batchSize,1,1)

   real_ctx[{{},{1},{},{}}][mask_global] = 2*117.0/255.0 - 1.0
   real_ctx[{{},{2},{},{}}][mask_global] = 2*104.0/255.0 - 1.0
   real_ctx[{{},{3},{},{}}][mask_global] = 2*123.0/255.0 - 1.0
   input_ctx:copy(real_ctx)
   data_tm:stop()

   label:fill(real_label)
   local output
   if opt.conditionAdv then
      output = netD:forward({input_ctx,input_center})
   else
      output = netD:forward(input_center)
   end
   local errD_real = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   if opt.conditionAdv then
      netD:backward({input_ctx,input_center}, df_do)
   else
      netD:backward(input_center, df_do)
   end
   
   -- train with fake
   if opt.noisetype == 'uniform' then -- regenerate random noise
       noise:uniform(-1, 1)
   elseif opt.noisetype == 'normal' then
       noise:normal(0, 1)
   end
   local fake
   if opt.noiseGen then
      fake = netG:forward({input_ctx,noise})
   else
      fake = netG:forward(input_ctx)
   end
   input_center:copy(fake)
   label:fill(fake_label)

   local output
   if opt.conditionAdv then
      output = netD:forward({input_ctx,input_center})
   else
      output = netD:forward(input_center)
   end
   local errD_fake = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   if opt.conditionAdv then
      netD:backward({input_ctx,input_center}, df_do)
   else
      netD:backward(input_center, df_do)
   end

   errD = errD_real + errD_fake

   return errD, gradParametersD
end

-- create closure to evaluate f(X) and df/dX of generator
local fGx = function(x)
   netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
   netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)

   gradParametersG:zero()

   --[[ the three lines below were already executed in fDx, so save computation
   noise:uniform(-1, 1) -- regenerate random noise
   local fake = netG:forward({input_ctx,noise})
   input_center:copy(fake) ]]--
   label:fill(real_label) -- fake labels are real for generator cost

   local output = netD.output -- netD:forward({input_ctx,input_center}) was already executed in fDx, so save computation
   errG = criterion:forward(output, label)
   local df_do = criterion:backward(output, label)
   local df_dg
   if opt.conditionAdv then
      df_dg = netD:updateGradInput({input_ctx,input_center}, df_do)
      df_dg = df_dg[2]     -- df_dg[2] because conditional GAN
   else
      df_dg = netD:updateGradInput(input_center, df_do)
   end

   local errG_total = errG
   if opt.wtl2~=0 then
      errG_l2 = criterionMSE:forward(input_center, input_real_center)
      local df_dg_l2 = criterionMSE:backward(input_center, input_real_center)

      if opt.useOverlapPred==0 then
        if (opt.wtl2>0 and opt.wtl2<1) then
          df_dg:mul(1-opt.wtl2):add(opt.wtl2,df_dg_l2)
          errG_total = (1-opt.wtl2)*errG + opt.wtl2*errG_l2
        else
          df_dg:add(opt.wtl2,df_dg_l2)
          errG_total = errG + opt.wtl2*errG_l2
        end
      else
        local overlapL2Weight = 10
        local wtl2Matrix = df_dg_l2:clone():fill(overlapL2Weight*opt.wtl2)
        for i=1,3 do
          wtl2Matrix[{{},{i},{},{}}][mask_global] = opt.wtl2
        end
        if (opt.wtl2>0 and opt.wtl2<1) then
          df_dg:mul(1-opt.wtl2):addcmul(1,wtl2Matrix,df_dg_l2)
          errG_total = (1-opt.wtl2)*errG + opt.wtl2*errG_l2
        else
          df_dg:addcmul(1,wtl2Matrix,df_dg_l2)
          errG_total = errG + opt.wtl2*errG_l2
        end
      end
   end

   if opt.noiseGen then
      netG:backward({input_ctx,noise}, df_dg)
   else
      netG:backward(input_ctx, df_dg)
   end

   return errG_total, gradParametersG
end

---------------------------------------------------------------------------
-- Train Context Encoder
---------------------------------------------------------------------------
for epoch = 1, opt.niter do
   epoch_tm:reset()
   local counter = 0
   for i = 1, math.min(data:size(), opt.ntrain), opt.batchSize do
      tm:reset()
      -- (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
      optim.adam(fDx, parametersD, optimStateD)

      -- (2) Update G network: maximize log(D(G(z)))
      optim.adam(fGx, parametersG, optimStateG)

      -- display
      counter = counter + 1
      if counter % opt.display_iter == 0 and opt.display then
          local real_ctx = data:getBatch()
          -- disp.image(real_ctx, {win=opt.display_id * 6, title=opt.name})

          local mask, wastedIter
          wastedIter = 0
          while true do
           local x = torch.uniform(1, MAX_SIZE-opt.fineSize)
           local y = torch.uniform(1, MAX_SIZE-opt.fineSize)
           mask = pattern[{{y,y+opt.fineSize-1},{x,x+opt.fineSize-1}}]  -- view, no allocation
           local area = mask:sum()*100./(opt.fineSize*opt.fineSize)
           if area>20 and area<30 then  -- want it to be approx 75% 0s and 25% 1s
              -- print('wasted tries: ',wastedIter)
              break
           end
           wastedIter = wastedIter + 1
          end
          mask=torch.repeatTensor(mask,opt.batchSize,1,1)

          real_ctx[{{},{1},{},{}}][mask] = 2*117.0/255.0 - 1.0
          real_ctx[{{},{2},{},{}}][mask] = 2*104.0/255.0 - 1.0
          real_ctx[{{},{3},{},{}}][mask] = 2*123.0/255.0 - 1.0
          input_ctx_vis:copy(real_ctx)

          local fake
          if opt.noiseGen then
            fake = netG:forward({input_ctx_vis,noise_vis})
          else
            fake = netG:forward(input_ctx_vis)
          end
          disp.image(fake, {win=opt.display_id, title=opt.name})
          
          real_ctx[{{},{1},{},{}}][mask] = 1.0
          real_ctx[{{},{2},{},{}}][mask] = 1.0
          real_ctx[{{},{3},{},{}}][mask] = 1.0
          disp.image(real_ctx, {win=opt.display_id * 3, title=opt.name})
      end

      -- logging
      if ((i-1) / opt.batchSize) % 1 == 0 then
         print(('Epoch: [%d][%8d / %8d]\t Time: %.3f  DataTime: %.3f  '
                   .. '  Err_G_L2: %.4f   Err_G: %.4f  Err_D: %.4f'):format(
                 epoch, ((i-1) / opt.batchSize),
                 math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
                 tm:time().real, data_tm:time().real, errG_l2 or -1,
                 errG and errG or -1, errD and errD or -1))
      end
   end
   paths.mkdir('checkpoints')
   parametersD, gradParametersD = nil, nil -- nil them to avoid spiking memory
   parametersG, gradParametersG = nil, nil
   if epoch % 20 == 0 then
      util.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_G.t7', netG, opt.gpu)
      util.save('checkpoints/' .. opt.name .. '_' .. epoch .. '_net_D.t7', netD, opt.gpu)
   end
   parametersD, gradParametersD = netD:getParameters() -- reflatten the params and get them
   parametersG, gradParametersG = netG:getParameters()
   print(('End of epoch %d / %d \t Time Taken: %.3f'):format(
            epoch, opt.niter, epoch_tm:time().real))
end

Advertisements
Loading...

We use cookies to provide and improve our services. By using our site, you consent to our Cookies Policy.