Adam TF and SymJAX

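This example runs the same simple optimization in TensorFlow (TF) and SymJAX (SJ): fitting a linear model under a quadratic loss with Adam and with plain gradient descent. It also compares the exponential moving average (EMA) computed by each library.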
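For orientation, the Adam update being compared can be sketched in plain NumPy. This is the textbook form of the rule, not the code of either library; the helper names (`adam_minimize`, `mse`, `mse_grad`), the toy data, and the default hyperparameters are assumptions made for illustration, and TF or SymJAX may differ in small details such as where epsilon is applied.

```python
import numpy as np


def adam_minimize(grad_fn, w0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=400):
    """Textbook Adam loop (illustrative sketch, not a library implementation)."""
    w = np.asarray(w0, dtype=np.float64).copy()
    m = np.zeros_like(w)  # running mean of the gradients
    v = np.zeros_like(w)  # running mean of the squared gradients
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        m_hat = m / (1 - beta1 ** t)  # correct the bias from zero initialization
        v_hat = v / (1 - beta2 ** t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w


# Toy least-squares problem (hypothetical data, much smaller than the script's).
rng = np.random.default_rng(0)
features = rng.standard_normal((100, 5))
targets = features @ rng.standard_normal(5)


def mse(w):
    return np.mean((features @ w - targets) ** 2)


def mse_grad(w):
    return 2 * features.T @ (features @ w - targets) / len(targets)


w_init = np.zeros(5)
w_fit = adam_minimize(mse_grad, w_init)
print("initial loss:", mse(w_init), "final loss:", mse(w_fit))
```

Both libraries are expected to follow essentially this recursion, which is why the solid (TF) and dashed (SJ) curves in the figures are meant to overlap.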

  • Figure 1: plot compare adam
  • Figure 2: Adam Optimization quadratic loss (-:TF, --:SJ); panel titles: lr:0.1, lr:0.1
  • Figure 3: GD Optimization quadratic loss (-:TF, --:SJ); panel titles: lr:0.1, lr:0.1

Out:

Placeholder(name=x, shape=(), dtype=float32, scope=/) Op(name=true_divide, fn=true_divide, shape=(), dtype=float32, scope=/ExponentialMovingAverage/)
[Variable(name=EMA, shape=(), dtype=float32, trainable=False, scope=/ExponentialMovingAverage/), Placeholder(name=x, shape=(), dtype=float32, scope=/), Variable(name=num_steps, shape=(), dtype=int32, trainable=False, scope=/ExponentialMovingAverage/)]
Placeholder(name=x, shape=(), dtype=float32, scope=/) Op(name=add, fn=add, shape=(), dtype=float32, scope=/ExponentialMovingAverage/)
[Placeholder(name=x, shape=(), dtype=float32, scope=/), Variable(name=EMA, shape=(), dtype=float32, trainable=False, scope=/ExponentialMovingAverage/)]


import matplotlib.pyplot as plt

import symjax
import symjax.tensor as T
from symjax.nn import optimizers
import numpy as np
from tqdm import tqdm


BS = 1000
D = 500
X = np.random.randn(BS, D).astype("float32")
Y = X.dot(np.random.randn(D, 1).astype("float32")) + 2


def TF1(x, y, N, lr, model, preallocate=False):
    import tensorflow.compat.v1 as tf

    # tf is already the compat.v1 module, so call the switch on it directly
    tf.disable_v2_behavior()
    tf.reset_default_graph()

    tf_input = tf.placeholder(dtype=tf.float32, shape=[BS, D])
    tf_output = tf.placeholder(dtype=tf.float32, shape=[BS, 1])

    np.random.seed(0)

    tf_W = tf.Variable(np.random.randn(D, 1).astype("float32"))
    tf_b = tf.Variable(
        np.random.randn(
            1,
        ).astype("float32")
    )

    tf_loss = tf.reduce_mean((tf.matmul(tf_input, tf_W) + tf_b - tf_output) ** 2)
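    if model == "SGD":
        train_op = tf.train.GradientDescentOptimizer(lr).minimize(tf_loss)
    elif model == "Adam":
        train_op = tf.train.AdamOptimizer(lr).minimize(tf_loss)
    else:
        # without this, an unknown name would leave train_op undefined
        raise ValueError("model must be 'SGD' or 'Adam', got " + repr(model))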

    # initialize session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    losses = []
    for i in tqdm(range(N)):
        losses.append(
            sess.run([tf_loss, train_op], feed_dict={tf_input: x, tf_output: y})[0]
        )

    return losses


def TF_EMA(X):
    import tensorflow.compat.v1 as tf

    # tf is already the compat.v1 module, so call the switch on it directly
    tf.disable_v2_behavior()
    tf.reset_default_graph()
    x = tf.placeholder("float32")
    # Create an ExponentialMovingAverage object
    ema = tf.train.ExponentialMovingAverage(decay=0.9)
    op = ema.apply([x])
    out = ema.average(x)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer(), feed_dict={x: X[0]})

    outputs = []
    for i in range(len(X)):
        sess.run(op, feed_dict={x: X[i]})
        outputs.append(sess.run(out))
    return outputs


def SJ_EMA(X, debias=True):
    symjax.current_graph().reset()
    x = T.Placeholder((), "float32", name="x")
    value = symjax.nn.schedules.ExponentialMovingAverage(x, 0.9, debias=debias)[0]
    print(x, value)
    print(symjax.current_graph().roots(value))
    train = symjax.function(x, outputs=value, updates=symjax.get_updates())
    outputs = []
    for i in range(len(X)):
        outputs.append(train(X[i]))
    return outputs


def SJ(x, y, N, lr, model, preallocate=False):
    symjax.current_graph().reset()
    sj_input = T.Placeholder(dtype=np.float32, shape=[BS, D])
    sj_output = T.Placeholder(dtype=np.float32, shape=[BS, 1])

    np.random.seed(0)

    sj_W = T.Variable(np.random.randn(D, 1).astype("float32"))
    sj_b = T.Variable(
        np.random.randn(
            1,
        ).astype("float32")
    )

    sj_loss = ((sj_input.dot(sj_W) + sj_b - sj_output) ** 2).mean()

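    if model == "SGD":
        optimizers.SGD(sj_loss, lr)
    elif model == "Adam":
        optimizers.Adam(sj_loss, lr)
    else:
        # without this, an unknown name would silently train nothing
        raise ValueError("model must be 'SGD' or 'Adam', got " + repr(model))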
    train = symjax.function(
        sj_input, sj_output, outputs=sj_loss, updates=symjax.get_updates()
    )

    losses = []
    for i in tqdm(range(N)):
        losses.append(train(x, y))

    return losses


sample = np.random.randn(100)

plt.figure()
plt.plot(sample, label="Original signal", alpha=0.5)
plt.plot(TF_EMA(sample), c="orange", label="TF ema", linewidth=2, alpha=0.5)
# SJ_EMA defaults to debias=True, so the first curve is the debiased estimate
plt.plot(SJ_EMA(sample), c="green", label="SJ ema (debiased)", linewidth=2, alpha=0.5)
plt.plot(
    SJ_EMA(sample, False),
    c="green",
    linestyle="--",
    label="SJ ema (not debiased)",
    linewidth=2,
    alpha=0.5,
)
plt.legend()


plt.figure()
Ns = [400, 1000]
lrs = [0.001, 0.01, 0.1]
colors = ["r", "b", "g"]
for k, N in enumerate(Ns):
    plt.subplot(1, len(Ns), 1 + k)
    for c, lr in enumerate(lrs):
        loss = TF1(X, Y, N, lr, "Adam")
        plt.plot(loss, c=colors[c], linestyle="-", alpha=0.5)
        loss = SJ(X, Y, N, lr, "Adam")
        plt.plot(loss, c=colors[c], linestyle="--", alpha=0.5, linewidth=2)
        plt.title("lr:" + str(lr))
plt.suptitle("Adam Optimization quadratic loss (-:TF, --:SJ)")


plt.figure()
Ns = [400, 1000]
lrs = [0.001, 0.01, 0.1]
colors = ["r", "b", "g"]
for k, N in enumerate(Ns):
    plt.subplot(1, len(Ns), 1 + k)
    for c, lr in enumerate(lrs):
        loss = TF1(X, Y, N, lr, "SGD")
        plt.plot(loss, c=colors[c], linestyle="-", alpha=0.5)
        loss = SJ(X, Y, N, lr, "SGD")
        plt.plot(loss, c=colors[c], linestyle="--", alpha=0.5, linewidth=2)
        plt.title("lr:" + str(lr))
        plt.xlabel("steps")
plt.suptitle("GD Optimization quadratic loss (-:TF, --:SJ)")
plt.show()
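The `debias` flag passed to `SJ_EMA` above can be illustrated with a small NumPy sketch of a zero-initialized exponential moving average. This is the standard recursion with the usual `1 - decay**t` correction; whether SymJAX or TF implement exactly this form is an assumption, and the function name `ema` is hypothetical.

```python
import numpy as np


def ema(signal, decay=0.9, debias=True):
    """Zero-initialized EMA, optionally divided by (1 - decay**t) to remove startup bias."""
    avg = 0.0
    outputs = []
    for t, x in enumerate(signal, start=1):
        avg = decay * avg + (1 - decay) * x
        outputs.append(avg / (1 - decay ** t) if debias else avg)
    return np.array(outputs)


# On a constant signal the debiased estimate is exact from the first step,
# while the uncorrected one starts near zero and only approaches the true value.
constant = np.ones(50)
raw = ema(constant, debias=False)
corrected = ema(constant, debias=True)
print(raw[:3], corrected[:3])
```

The uncorrected average underestimates the signal during the first steps because the accumulator starts at zero; the correction rescales it by the total weight accumulated so far.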

Total running time of the script: (0 minutes 27.428 seconds)
