python - 使 tfp.sts.fit_with_hmc 更快

标签 python tensorflow tensorflow-probability

这是可能尝试学习如何使用 Tensorflow Probability 的一部分

我加载了代表每小时电能消耗的 1368 个值的时间序列。

我会使用季节性/自回归模型来生成一些预测。

此时下面的代码可以工作,但是执行速度非常慢。 这是可以预料的吗?我可以采取不同的措施来使其更快吗?

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp  
from tensorflow_probability import distributions as tfd
from tensorflow_probability import sts

print('tensorflow version', tf.version.VERSION)
# tensorflow version 2.0.0
print('tensorflow executing eagerly', tf.executing_eagerly())
# tensorflow executing eagerly True


# Load time series
time_series = [...] # shape (1368,)
# Extract observations used for fitting the model will use 1344 
# observation to forecast the next 24
observed_time_series = time_series[:-24] # shape (1344,)

# Define model
hour_of_day_effect = tfp.sts.Seasonal( num_seasons=24, observed_time_series=observed_time_series, name='hour_of_day_effect')
day_of_week_effect = tfp.sts.Seasonal( num_seasons=7, num_steps_per_season=24, observed_time_series=observed_time_series, name='day_of_week_effect')
residual_level = tfp.sts.Autoregressive( order=1, observed_time_series=observed_time_series, name='residual')

model = tfp.sts.Sum([hour_of_day_effect, day_of_week_effect, residual_level], observed_time_series=observed_time_series)


# Fit the model
samples, kernel_results = tfp.sts.fit_with_hmc(model=model, observed_time_series=observed_time_series)

# Make predicitons
forecast_dist = tfp.sts.forecast(model, observed_time_series,
                                 parameter_samples=samples,
                                 num_steps_forecast=24)


# Read predicions
forecast_mean = forecast_dist.mean().numpy()
forecast_scale = forecast_dist.stddev().numpy()
forecast_samples = forecast_dist.sample(10).numpy()

如果有人想尝试这些就是观察结果

[166.8, 160.5, 155.7, 156.6, 157.5, 155.4, 152.4, 150.3, 146.7, 141.9, 138.6, 124.2, 135.3, 134.4, 140.7, 144.9, 150.3, 148.8, 148.5, 147.6, 149.7, 152.4, 199.8, 306.3, 308.1, 306.6, 313.8, 308.1, 319.5, 310.5, 647.1, 1002.9, 1191.0, 1014.3, 1074.9, 921.6, 759.3, 940.5, 946.5, 984.9, 1014.6, 1016.4, 868.2, 478.5, 664.2, 647.7, 627.6, 678.9, 614.4, 660.0, 679.8, 559.8, 570.6, 561.3, 630.6, 1049.1, 1048.5, 1014.3, 1011.9, 903.9, 1014.0, 1108.5, 1240.8, 979.2, 852.3, 819.3, 558.9, 451.2, 613.5, 605.4, 631.2, 647.4, 607.2, 632.1, 580.2, 531.6, 579.3, 556.5, 546.0, 850.5, 1122.9, 1149.0, 1032.0, 977.1, 994.8, 1017.0, 993.3, 873.3, 880.5, 736.2, 580.5, 517.8, 641.7, 593.4, 621.0, 626.7, 660.0, 655.5, 595.5, 526.8, 605.4, 560.7, 611.1, 944.4, 1151.7, 1154.7, 1143.9, 999.6, 1006.2, 1147.2, 1058.1, 949.8, 726.3, 566.4, 582.6, 504.3, 640.2, 660.9, 687.9, 651.0, 634.8, 639.0, 621.6, 559.2, 654.3, 571.5, 663.0, 1005.0, 1130.7, 1101.0, 1152.9, 1076.4, 1003.5, 1128.3, 780.6, 774.6, 737.1, 506.4, 537.0, 399.0, 619.8, 578.1, 567.9, 597.0, 597.0, 632.7, 620.7, 533.7, 615.6, 510.6, 504.0, 535.2, 448.2, 350.7, 411.3, 328.2, 222.9, 225.3, 200.7, 179.4, 154.8, 160.8, 155.1, 153.9, 157.2, 161.4, 162.0, 162.0, 160.5, 162.6, 164.1, 165.6, 165.6, 164.7, 158.4, 153.9, 147.9, 138.3, 137.4, 134.4, 136.2, 140.1, 143.7, 141.6, 145.5, 149.4, 152.1, 154.8, 156.9, 159.3, 207.3, 318.3, 369.6, 361.8, 358.5, 363.0, 356.4, 367.8, 691.8, 1048.8, 1191.9, 1305.0, 1288.8, 1017.3, 1165.5, 1252.5, 1065.9, 1172.4, 880.8, 817.5, 719.7, 541.8, 690.3, 663.0, 677.4, 687.0, 721.2, 705.0, 733.8, 527.4, 671.7, 613.5, 761.1, 967.2, 1083.9, 1006.5, 1142.7, 990.3, 1047.9, 1269.9, 1139.1, 910.5, 672.6, 623.1, 585.9, 486.0, 611.1, 605.7, 648.3, 642.9, 624.0, 639.9, 642.9, 501.0, 639.6, 599.4, 612.6, 923.7, 1098.0, 1131.9, 1183.8, 1068.9, 1111.2, 1180.8, 1126.2, 1124.4, 788.1, 807.3, 663.9, 442.5, 682.2, 681.6, 743.1, 691.5, 694.5, 734.4, 727.2, 516.3, 696.6, 659.1, 742.8, 975.6, 1113.0, 1180.5, 1140.0, 1154.1, 802.2, 1245.6, 1117.2, 1207.8, 1070.7, 870.6, 661.5, 566.1, 726.6, 657.3, 696.9, 631.8, 527.4, 449.1, 447.3, 691.8, 666.6, 641.7, 714.9, 1009.5, 1141.2, 1272.0, 1264.2, 1118.1, 936.9, 1319.1, 1167.9, 967.2, 907.5, 704.1, 562.5, 397.8, 554.7, 500.4, 524.1, 244.8, 195.0, 187.2, 186.0, 177.9, 166.5, 163.5, 160.8, 157.5, 154.2, 150.6, 149.7, 147.0, 143.7, 139.2, 139.8, 139.5, 142.8, 145.5, 144.9, 137.1, 138.6, 143.1, 144.9, 145.2, 146.7, 144.6, 144.9, 147.0, 147.9, 147.9, 146.4, 144.0, 141.3, 138.6, 138.0, 132.6, 132.6, 136.2, 137.4, 140.4, 141.3, 147.0, 147.6, 149.1, 151.8, 157.2, 186.0, 315.6, 316.2, 312.0, 300.3, 289.8, 294.9, 306.3, 388.2, 798.3, 1233.9, 1274.7, 1303.2, 1095.9, 1060.2, 1199.7, 1098.6, 894.3, 927.9, 647.4, 554.7, 423.0, 420.6, 368.4, 369.6, 360.3, 337.5, 324.6, 316.8, 312.9, 311.1, 301.8, 389.4, 744.9, 837.6, 1230.6, 1135.8, 1016.7, 1078.8, 1253.4, 1125.9, 915.3, 912.0, 839.7, 642.9, 492.0, 663.3, 632.4, 362.7, 353.1, 344.1, 343.8, 331.2, 323.7, 307.2, 304.8, 636.9, 1046.7, 1037.4, 1086.9, 880.2, 968.4, 1002.3, 995.7, 1131.0, 943.8, 910.2, 732.6, 580.2, 514.8, 587.4, 606.6, 366.9, 373.5, 369.6, 357.0, 336.6, 330.3, 330.3, 323.7, 596.4, 944.7, 1134.6, 1173.3, 1148.1, 934.5, 991.8, 1081.5, 1168.8, 920.1, 671.4, 665.7, 621.6, 567.0, 627.0, 596.1, 359.4, 363.3, 357.0, 340.8, 321.0, 325.8, 311.4, 296.1, 684.3, 1080.6, 1169.1, 1118.7, 1005.6, 921.0, 901.5, 845.1, 1088.7, 1146.3, 952.5, 760.2, 643.8, 592.5, 703.5, 625.5, 309.9, 305.4, 294.6, 297.3, 288.3, 276.0, 267.9, 260.1, 285.3, 665.7, 706.8, 621.3, 561.6, 451.8, 276.3, 242.4, 203.7, 191.7, 190.8, 188.1, 186.9, 185.1, 183.3, 186.3, 188.1, 189.6, 186.9, 176.1, 174.0, 171.0, 171.6, 170.7, 168.9, 167.7, 165.3, 156.9, 153.6, 151.5, 149.7, 146.1, 144.3, 145.8, 148.8, 152.1, 160.8, 160.2, 159.0, 164.1, 209.7, 332.7, 332.7, 330.3, 329.1, 340.5, 327.9, 333.3, 654.6, 1032.6, 1190.1, 1114.8, 1175.7, 981.9, 1065.0, 1223.7, 987.3, 1087.8, 857.4, 867.3, 558.3, 484.2, 575.7, 478.2, 344.7, 332.1, 340.5, 330.0, 318.9, 310.8, 308.4, 308.4, 564.0, 925.8, 1119.9, 1041.6, 1089.6, 1046.7, 964.8, 1084.8, 1025.4, 1147.8, 894.9, 637.5, 603.6, 516.6, 643.8, 536.4, 324.6, 330.0, 318.6, 315.0, 300.9, 290.7, 291.0, 291.0, 626.1, 960.3, 1035.3, 1300.2, 1202.1, 1042.2, 955.5, 1051.8, 1004.4, 970.8, 728.7, 642.6, 576.6, 487.2, 617.7, 573.3, 351.9, 334.8, 326.7, 311.1, 294.9, 291.3, 285.9, 292.8, 617.1, 891.9, 1033.8, 1281.3, 1053.0, 1090.2, 841.8, 913.5, 1146.3, 954.3, 1160.4, 707.1, 556.2, 515.1, 607.2, 538.2, 323.7, 329.1, 326.1, 326.1, 305.7, 297.9, 293.4, 292.2, 626.4, 961.5, 1124.1, 1192.8, 1179.3, 889.8, 904.8, 1141.5, 1155.6, 1025.7, 784.8, 780.6, 609.6, 523.5, 696.9, 611.7, 249.3, 237.0, 230.7, 229.2, 225.6, 218.4, 213.3, 207.0, 240.0, 264.9, 324.0, 457.8, 671.7, 287.7, 251.7, 229.8, 208.8, 183.9, 172.5, 168.0, 165.6, 167.4, 163.8, 165.3, 153.3, 148.8, 147.9, 160.2, 162.3, 159.9, 158.1, 157.5, 154.2, 151.8, 150.3, 146.1, 141.0, 139.8, 138.9, 136.8, 135.6, 135.3, 141.0, 142.2, 147.0, 147.6, 150.9, 159.3, 251.4, 287.7, 284.7, 279.6, 279.3, 274.5, 269.1, 281.7, 642.6, 891.0, 882.6, 1244.7, 1292.1, 1031.7, 1034.4, 1079.7, 1082.1, 992.4, 726.3, 656.4, 624.9, 468.6, 671.4, 618.9, 654.0, 641.1, 660.9, 672.0, 634.2, 436.5, 585.3, 561.0, 649.5, 1024.8, 999.0, 1040.4, 1151.1, 966.9, 965.1, 1015.2, 853.2, 859.2, 621.9, 598.8, 540.0, 422.4, 591.9, 579.3, 568.5, 607.8, 574.8, 594.9, 548.7, 430.8, 534.9, 516.9, 545.1, 913.8, 1047.3, 1225.8, 1070.1, 1035.9, 887.7, 962.7, 1032.9, 962.7, 837.3, 617.7, 578.7, 474.3, 668.4, 647.7, 616.5, 647.1, 621.0, 604.2, 544.5, 355.5, 321.9, 321.0, 354.0, 689.1, 959.1, 1209.0, 1011.0, 799.5, 792.0, 987.6, 1190.4, 1211.1, 757.5, 744.6, 621.9, 518.1, 613.2, 594.9, 625.2, 648.9, 639.9, 631.5, 618.0, 417.3, 598.2, 551.7, 972.6, 916.2, 975.0, 1095.0, 1122.9, 978.6, 968.1, 1042.2, 814.5, 737.1, 638.7, 584.7, 612.9, 359.1, 611.4, 540.6, 517.8, 527.1, 492.6, 501.0, 513.9, 324.9, 471.0, 357.6, 207.0, 216.9, 214.8, 203.1, 207.3, 192.3, 193.8, 192.6, 167.1, 169.5, 175.2, 157.2, 154.2, 156.0, 159.3, 161.4, 155.7, 149.4, 147.0, 139.5, 135.6, 135.3, 134.7, 136.2, 132.3, 128.7, 128.4, 124.5, 121.8, 121.5, 118.5, 119.7, 115.2, 118.2, 120.3, 123.0, 121.2, 118.8, 120.3, 142.5, 274.2, 285.3, 283.5, 277.8, 283.2, 281.1, 262.5, 270.0, 578.4, 1009.2, 933.3, 1274.1, 1138.2, 966.6, 1014.9, 1126.5, 1017.0, 1128.9, 998.4, 927.0, 611.7, 504.9, 641.4, 598.5, 603.3, 604.8, 590.1, 611.4, 478.5, 622.2, 604.2, 554.7, 600.6, 953.1, 1062.9, 1150.5, 1262.4, 889.2, 985.2, 1039.2, 850.2, 851.7, 811.2, 612.6, 600.6, 521.7, 615.3, 582.3, 523.8, 314.1, 323.4, 314.4, 318.0, 296.1, 294.6, 286.2, 371.4, 641.4, 847.8, 927.6, 950.4, 1026.9, 899.1, 1041.0, 1085.7, 981.3, 900.6, 966.0, 825.9, 689.7, 843.0, 671.1, 580.5, 596.4, 600.3, 567.9, 473.4, 483.3, 527.1, 589.8, 362.4, 767.1, 795.3, 867.3, 846.6, 711.9, 827.7, 831.6, 838.8, 852.6, 682.8, 599.7, 343.5, 307.2, 303.6, 307.2, 325.5, 305.7, 302.7, 296.7, 299.1, 282.3, 269.7, 271.8, 327.6, 625.2, 743.1, 861.3, 955.5, 854.7, 824.7, 985.8, 747.6, 756.9, 634.2, 559.5, 337.8, 296.1, 299.4, 288.3, 275.1, 269.7, 261.0, 256.8, 255.0, 249.3, 240.6, 228.0, 252.3, 338.1, 614.4, 433.8, 543.0, 427.5, 255.9, 258.9, 232.5, 225.0, 225.0, 219.3, 145.2, 104.1, 107.1, 107.7, 109.5, 109.2, 106.8, 106.5, 106.8, 107.1, 106.5, 105.6, 101.4, 101.4, 99.3, 96.6, 95.1, 94.2, 92.1, 93.9, 97.2, 98.4, 101.7, 100.5, 103.8, 106.2, 108.0, 114.6, 190.8, 290.4, 300.0, 290.1, 274.8, 269.1, 260.4, 265.2, 326.1, 618.3, 782.1, 852.3, 834.6, 686.1, 733.5, 760.5, 675.3, 705.9, 697.2, 649.5, 504.0, 447.9, 386.1, 388.2, 629.7, 684.9, 639.9, 643.8, 591.3, 573.9, 638.7, 621.3, 678.3, 892.2, 859.5, 1011.3, 1100.1, 964.2, 892.5, 1309.5, 1225.2, 979.8, 879.9, 670.2, 573.3, 534.9, 675.9, 588.6, 591.0, 570.0, 560.1, 524.7, 605.4, 549.9, 608.7, 593.4, 612.6, 748.2, 1161.0, 1360.8, 1260.3, 1055.1, 1070.4, 1176.0, 1069.5, 837.0, 709.5, 595.8, 582.6, 422.4, 643.8, 611.1, 635.4, 637.2, 590.1, 628.2, 606.9, 536.7, 580.2, 573.9, 582.0, 889.5, 1035.3, 1139.7, 979.8, 1007.7, 995.4, 1029.3, 1041.3, 843.0, 816.3, 702.3, 555.3, 522.9, 662.7, 641.4, 630.3, 621.6, 609.6, 631.8, 596.7, 540.3, 605.7, 594.3, 607.8, 944.1, 1200.0, 1113.0, 1132.8, 900.6, 918.0, 1063.8, 1040.1, 894.6, 902.1, 585.6, 567.3, 492.6, 540.3, 521.1, 503.7, 513.6, 476.7, 511.8, 502.8, 453.0, 495.6, 483.3, 529.2, 549.6, 513.9, 546.3, 505.5, 447.6, 400.5, 423.6, 192.6, 177.0, 170.1, 166.5, 160.8, 158.7, 146.4, 146.7, 149.4, 149.7, 148.5, 149.7, 150.3, 150.0, 146.1, 141.9, 137.4, 135.6, 133.2, 133.5, 129.0, 127.8, 125.7, 126.9, 123.6, 124.5, 128.1, 127.8, 130.8, 135.3, 138.6, 142.8, 180.6, 279.6, 312.3, 306.0, 308.1, 301.8, 298.8, 318.9, 624.6, 971.1, 1068.6, 1181.4, 1173.0, 964.5, 998.7, 1062.3, 1038.3, 1148.4, 749.4, 769.5, 739.5, 465.0, 379.8, 383.1, 628.8, 662.7, 633.6, 687.0, 700.5, 591.9, 632.1, 640.5, 642.3, 1128.6, 1116.9, 1007.7, 1213.8, 1011.9, 886.2, 1010.1, 1036.5, 1179.0, 837.9, 606.0, 550.5, 429.9, 603.0, 577.8, 589.2, 625.5, 585.3, 635.7, 611.7, 523.5, 560.7, 553.5, 578.4, 816.0, 926.1, 1185.6, 1153.5, 954.0, 1083.6, 1236.9, 1121.1, 1032.0, 813.0, 598.8, 537.3, 477.9, 543.9, 374.7, 477.6, 387.6, 380.4, 357.6, 354.6, 346.2, 336.0, 338.4, 358.5, 700.2, 876.0, 777.9, 1008.3, 974.4, 989.4, 1017.0, 1246.8, 946.2, 861.9, 756.9, 564.3, 479.4, 605.1, 590.7, 667.2, 672.0, 630.9, 648.3, 588.3, 560.4, 585.6, 587.4, 622.5, 917.4, 1058.7, 1274.1, 1172.1, 1041.0, 1035.0, 1227.9, 1012.2, 980.7, 870.9, 699.0, 544.8, 456.9, 600.0, 548.1, 520.2, 530.4, 513.9, 554.1, 509.7, 445.2, 502.8, 492.0, 512.1, 489.6, 451.5, 511.2, 509.7, 438.0, 449.7, 440.7, 168.0, 165.6, 157.2, 166.2, 156.3, 156.9, 156.0, 152.4, 151.8, 151.5, 153.6, 147.9, 147.9, 148.5, 149.1, 148.2, 145.8, 144.3, 142.5, 139.2, 138.3, 137.7, 138.0, 139.5, 137.4, 140.4, 138.6, 138.9, 141.9, 146.7, 148.5, 153.6, 184.8, 324.3]```

最佳答案

如果你把它放在一个

中应该会快一点
@tf.function(experimental_compile=True)
def f(...):
  return ...

我们正在继续开发这个包。请随意在 GitHub 上提出问题以更深入地探讨这个问题。

关于python - 使 tfp.sts.fit_with_hmc 更快,我们在Stack Overflow上找到一个类似的问题: https://stackoverflow.com/questions/58239857/

相关文章:

python selenium 选择值与正则表达式匹配的下拉选项

python - 关于在 gzip 文件上使用 seek

tensorflow - 如何使用 tensorflow 占位符在 get_collection 中使用

math - tf.truncated_normal 和 tf.random_normal 有什么区别?

python - Keras:val_loss 和 val_accuracy 没有改变

python - 当从 tensorflow 概率的分布中采样时,张量是不可散列的错误(在colab上)

tensorflow-probability - Tensorflow 概率 Logistic 回归示例

Python:从 STFT 重建音频文件

python - 根据原始字典中的 N 个键创建 N 个新字典

python - 在 tensorflow 概率中指定 DirichletMultinomial