import numpy as np
import h5py
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import metrics
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from data_loader import *

# --- training configuration -------------------------------------------
nepochs = 300          # number of passes over the training set in model.fit
n_in, n_out = 5, 1     # width of the network input (Xtrain) and output (Ytrain)
# NOTE(review): neither fraction is read anywhere in the visible script;
# model.fit below uses a hard-coded validation_split=0.2 instead.
train_fraction, val_fraction = 0.8, 0.0

# Alternative input used previously:
#   "/m100_work/IscrC_CD-DLS/simulations/full_dtb_newBH12b_csf3dR_z0.0"
infilename = "/m100_work/IscrC_CD-DLS/simulations/"

# Read every field used below from the simulation output.
# NOTE(review): the two lists are presumably the start offset and extent of
# the loaded sub-volume -- confirm against data_loader.load_data.
rho, T, vx, vy, vz, Bx, By, Bz, B2 = load_data(infilename, [0, 0, 0], [200, 200, 200])

# Every loaded cell becomes one training sample (train_fraction is not applied).
n_train = n_cells = rho.size

# Assemble the training matrices: one row per cell, one column per field.
# Column order of Xtrain follows `input_fields`; of Ytrain, `output_fields`.
input_fields = (rho, T, vx, vy, vz)
if len(input_fields) != n_in:
    raise ValueError(
        f"n_in={n_in} but {len(input_fields)} input fields are defined"
    )

# Targets: either the three B components or the single B2 field.
if n_out == 3:
    output_fields = (Bx, By, Bz)
elif n_out == 1:
    output_fields = (B2,)
else:
    # Without this check any other n_out silently trained against an
    # all-zero Ytrain.
    raise ValueError(f"n_out must be 1 (B2) or 3 (Bx, By, Bz), got {n_out}")

Xtrain = np.zeros((n_train, n_in))
Ytrain = np.zeros((n_train, n_out))
# np.ravel returns a view when it can, unlike flatten() which always copies;
# assignment casts values into the float64 matrices allocated above.
for col, field in enumerate(input_fields):
    Xtrain[:, col] = np.ravel(field)[:n_train]
for col, field in enumerate(output_fields):
    Ytrain[:, col] = np.ravel(field)[:n_train]


# Define the Keras model: five hidden stages of Dense(relu) followed by
# BatchNormalization, then a linear output layer of width n_out.
# (Dropout(0.2) after each Dense was tried earlier and left disabled.)
n_nodes = 400

model = Sequential()
for hidden_idx in range(5):
    # Only the first layer declares the input width.
    shape_kwargs = {'input_shape': (n_in,)} if hidden_idx == 0 else {}
    model.add(Dense(n_nodes, activation='relu', **shape_kwargs))
    model.add(BatchNormalization())
model.add(Dense(n_out, activation='linear'))
model.summary()

import os

print("Model training for number of outputs = ",n_out)

# Compile the Keras model.
# `learning_rate` replaces the deprecated `lr` alias, which recent Keras
# releases no longer accept.
opt = Adam(learning_rate=0.0001)
model.compile(loss='mse', optimizer=opt, metrics=[metrics.mae])

# Train the network.
# NOTE(review): Keras takes validation_split from the *last* rows before any
# shuffling, i.e. a contiguous slab of the flattened volume rather than a
# random sample. train_fraction/val_fraction are not consulted here.
model.fit(Xtrain, Ytrain, validation_split=0.2, epochs=nepochs, batch_size=500)

# Save the trained network; create the target directory if it is missing.
# NOTE(review): "networ" looks like a typo, but the path is kept unchanged in
# case other scripts load the model by this name.
ckptfile = 'models/trained_networ.ckpt'
os.makedirs(os.path.dirname(ckptfile), exist_ok=True)
model.save(ckptfile)

