tensorflow - TF Estimator gradient boosted trees classifier suddenly stops during training

Tags: tensorflow crash classification tensorflow-estimator

I trained a gradient boosted trees classifier using the TF example code from https://www.tensorflow.org/tutorials/estimators/boosted_trees_model_understanding

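However, the TF Estimator BoostedTreesClassifier suddenly stops during training.

I expected it to run for several steps at the beginning. Instead, it stops abruptly without printing any exception.

How can I find out why Python crashed?

It is hard to find the reason it stops.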
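The standard-library `faulthandler` module is one way to see where a silent crash happens. When the interpreter is killed by a fatal error, it dumps the Python traceback of each thread to stderr. Fatal errors include SIGSEGV, SIGABRT, or (on Windows) a fatal exception such as an access violation.

The sketch below assumes the training code is wrapped in a function. `run_training` is a hypothetical placeholder for that code. Running the script as `python -X faulthandler script.py` has the same effect without changing the code. `faulthandler` cannot help if the process is terminated from outside, for example by the OS.

```python
import faulthandler
import sys


def run_training():
    # Hypothetical placeholder: put the estimator code here
    # (est.train(...), est.evaluate(...), est.predict(...)).
    pass


def main():
    # On a fatal error (SIGSEGV, SIGFPE, SIGABRT, SIGBUS, SIGILL, or a fatal
    # Windows exception), print the Python traceback of every thread.
    # sys.__stderr__ is the original stderr stream, which has a real file descriptor.
    faulthandler.enable(file=sys.__stderr__, all_threads=True)
    run_training()


if __name__ == "__main__":
    main()
```

If the crash happens inside native TensorFlow code, the dumped traceback shows which Python call (for example `est.train`) was executing at that moment.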

Environment:

lib : TF-gpu 1.13.1
cuda : 10.0
cudnn : 7.5
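This is the GPU build, so one quick check is to rerun the same script with the GPU hidden from TensorFlow. Setting the environment variable `CUDA_VISIBLE_DEVICES` to `-1` before TensorFlow is imported makes CUDA report no devices, and TensorFlow then falls back to the CPU. If training then completes (more slowly), the GPU, and GPU memory in particular, is a likely suspect.

The sketch below is minimal. `hide_gpus` is a hypothetical helper name. From `cmd.exe`, `set CUDA_VISIBLE_DEVICES=-1` before running the script does the same thing.

```python
import os


def hide_gpus() -> None:
    """Hide every CUDA device from this process so TensorFlow runs on the CPU."""
    # CUDA reads this variable when it initializes, so call this before
    # importing TensorFlow (or anything else that initializes CUDA).
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


if __name__ == "__main__":
    hide_gpus()
    # import tensorflow as tf   # import TensorFlow only after hiding the GPUs
    # ... rest of the training script unchanged ...
```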

Log:

2019-04-15 16:40:26.175889: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 0 with properties:
name: GeForce GTX 1060 6GB major: 6 minor: 1 memoryClockRate(GHz): 1.7845
pciBusID: 0000:07:00.0
totalMemory: 6.00GiB freeMemory: 4.97GiB
2019-04-15 16:40:26.182620: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1512] Adding visible gpu devices: 0
2019-04-15 16:40:26.832040: I tensorflow/core/common_runtime/gpu/gpu_device.cc:984] Device interconnect StreamExecutor with strength 1 edge matrix:
2019-04-15 16:40:26.835620: I tensorflow/core/common_runtime/gpu/gpu_device.cc:990] 0
2019-04-15 16:40:26.836840: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 0: N
2019-04-15 16:40:26.838276: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 4716 MB memory) -> physical GPU (device: 0, name: GeForce GTX 1060 6GB, pci bus id: 0000:07:00.0, compute capability: 6.1)
WARNING:tensorflow:From D:\python\lib\site-packages\tensorflow\python\training\saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
WARNING:tensorflow:From D:\python\lib\site-packages\tensorflow\python\training\saver.py:1070: get_checkpoint_mtimes (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file utilities to get mtimes.
WARNING:tensorflow:Issue encountered when serializing resources. Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore. '_Resource' object has no attribute 'name'
WARNING:tensorflow:Issue encountered when serializing resources. Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore. '_Resource' object has no attribute 'name'

D:\py>   (training just ends here and the prompt comes back, with no error message)
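A silent return to the prompt usually means the process terminated abnormally, and its exit status can say more. In `cmd.exe`, running `echo %errorlevel%` right after the script shows that status.

The sketch below is a small launcher that runs the training script in a child process with `faulthandler` enabled and decodes the return code. Both function names are hypothetical. On Windows, crashes usually show up as 32-bit NTSTATUS codes such as `0xC0000005` (access violation). On POSIX, a negative return code `-N` means the child was killed by signal `N`.

```python
import signal
import subprocess
import sys


def describe_exit(returncode: int) -> str:
    """Turn a child process's return code into a readable description."""
    if returncode == 0:
        return "exited normally (0)"
    if sys.platform == "win32":
        # Windows reports crashes as 32-bit NTSTATUS codes, e.g. 0xC0000005
        # (access violation); mask to 32 bits and show it in hex.
        return "exit code 0x%08X" % (returncode & 0xFFFFFFFF)
    if returncode < 0:
        # On POSIX, a negative code -N means the child was killed by signal N.
        try:
            name = signal.Signals(-returncode).name
        except ValueError:
            name = str(-returncode)
        return "killed by signal %s" % name
    return "exit code %d" % returncode


def run_script(path: str) -> int:
    """Run a Python script in a child process with faulthandler enabled."""
    proc = subprocess.run([sys.executable, "-X", "faulthandler", path])
    print(describe_exit(proc.returncode))
    return proc.returncode


if __name__ == "__main__" and len(sys.argv) > 1:
    # Hypothetical usage: python run_and_report.py train_boosted_trees.py
    run_script(sys.argv[1])
```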

import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split


def make_input_fn(X, y, n_epochs=None, shuffle=True):
    def input_fn():
        NUM_EXAMPLES = len(y)
        dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))
        # dataset = tf.data.Dataset.from_tensor_slices((X.to_dict(orient='list'), y))
        # Shuffling is currently disabled, so the `shuffle` argument has no effect:
        # if shuffle:
        #     dataset = dataset.shuffle(NUM_EXAMPLES)
        # For training, cycle through the dataset as many times as needed (n_epochs=None).
        # Each batch is the entire dataset.
        dataset = dataset.repeat(n_epochs).batch(NUM_EXAMPLES)
        return dataset
    return input_fn


def main():
    trn = pd.read_csv('data/santander-customer-transaction-prediction/train.csv')
    tst = pd.read_csv('data/santander-customer-transaction-prediction/test.csv')

    # trn = upsample(trn[trn.target==0], trn[trn.target==1])
    # trn = downsample(trn[trn.target==0], trn[trn.target==1])

    # Column 0 is the ID, column 1 the target, columns 2..201 the features.
    features = trn.columns.values[2:202]
    target_name = trn.columns.values[1]
    train = trn[features]
    target = trn[target_name]

    NUM_EXAMPLES = len(target)
    print(NUM_EXAMPLES)

    # Keep the 20 most negatively and the 20 most positively correlated features.
    corr = train.corrwith(target).sort_values()
    featonly = corr.head(20).index.append(corr.tail(20).index)

    train_origin, tt = train_test_split(trn, test_size=0.2)

    train = train_origin[featonly]
    target = train_origin[target_name]
    test = tst[featonly]

    val_train = tt[featonly]
    val_target = tt[target_name]

    # Training and evaluation input functions.
    train_input_fn = make_input_fn(train, target)
    # Evaluate with a single pass over the validation set.
    val_input_fn = make_input_fn(val_train, val_target, n_epochs=1, shuffle=False)

    ttt = tf.estimator.inputs.pandas_input_fn(x=test, num_epochs=1, shuffle=False)

    del train, target, val_train, train_origin, trn, tst

    fc = tf.feature_column
    feature_columns = [fc.numeric_column(name, dtype=tf.float32) for name in featonly]

    # tf.logging.set_verbosity(tf.logging.INFO)
    # logging_hook = tf.train.LoggingTensorHook({"loss": loss, "accuracy": accuracy}, every_n_iter=10)

    params = {
        'n_trees': 50,
        'max_depth': 3,
        'n_batches_per_layer': 1,
        # You must enable center_bias = True to get DFCs. This will force the model to
        # make an initial prediction before using any features (e.g. use the mean of
        # the training labels for regression or log odds for classification when
        # using cross entropy loss).
        'center_bias': True,
    }
    # config = tf.estimator.RunConfig().replace(keep_checkpoint_max=1,
    #                                           log_step_count_steps=20, save_checkpoints_steps=20)

    est = tf.estimator.BoostedTreesClassifier(feature_columns, model_dir=r'D:\py\model', **params)
    est.train(train_input_fn, max_steps=50)

    # ---------------- execution stops here; nothing below runs ----------------

    metrics = est.evaluate(input_fn=val_input_fn, steps=1)
    print(metrics)

    results = est.predict(input_fn=ttt)
    result_list = list(results)

    classi = [x['classes'][0].decode('utf-8') for x in result_list]
    numi = ['test_' + str(i) for i in range(len(classi))]

    df_result = pd.DataFrame({'ID_code': numi, 'target': classi})
    df_result.to_csv('result/submission03.csv', index=False)


if __name__ == '__main__':
    main()

Expected: the evaluation results should be displayed.

Best answer

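I think the problem is caused by running out of GPU memory. Try setting 'n_batches_per_layer' to a larger value, depending on the size of your GPU memory. I am using a 6 GB GPU with a value of 16.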
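The TensorFlow docs describe `n_batches_per_layer` as the number of batches over which statistics are collected for each layer. The total number of batches is the data size divided by the batch size. In the question's `make_input_fn`, each batch is the whole training set (`batch(NUM_EXAMPLES)`). If the goal is a smaller per-step memory footprint, it seems reasonable (an assumption, not something the answer states) to shrink the batch to match.

The rough arithmetic is sketched below. It assumes one tree layer is grown every `n_batches_per_layer` steps and ignores the extra bias-centering steps. The helper names and the 160000-row figure are hypothetical.

```python
import math


def batch_size_for(num_examples: int, n_batches_per_layer: int) -> int:
    """Batch size such that n_batches_per_layer batches cover the data once."""
    if num_examples <= 0 or n_batches_per_layer <= 0:
        raise ValueError("num_examples and n_batches_per_layer must be positive")
    return math.ceil(num_examples / n_batches_per_layer)


def approx_train_steps(n_trees: int, max_depth: int, n_batches_per_layer: int) -> int:
    """Rough step count to grow every tree to full depth.

    Assumes one layer is grown every n_batches_per_layer steps and ignores
    the extra steps used for bias centering (center_bias=True).
    """
    return n_trees * max_depth * n_batches_per_layer


# Hypothetical example: 160000 training rows, the question's tree settings
# (n_trees=50, max_depth=3) and n_batches_per_layer=16 as in the answer.
batch = batch_size_for(160000, 16)      # value to pass to dataset.batch(...)
steps = approx_train_steps(50, 3, 16)   # compare with max_steps=50 in the question
print(batch, steps)
```

Under these assumptions, `max_steps=50` stops training long before 50 trees are built, because `max_steps` is a hard limit. It would need to be raised along with `n_batches_per_layer`.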

A similar question, "tensorflow - TF Estimator gradient boosted trees classifier suddenly stops during training", can be found on Stack Overflow: https://stackoverflow.com/questions/55684807/
