{"id":547,"date":"2024-04-06T21:27:05","date_gmt":"2024-04-06T13:27:05","guid":{"rendered":"http:\/\/tobykskgd.life\/?p=547"},"modified":"2024-11-14T22:21:47","modified_gmt":"2024-11-14T14:21:47","slug":"22","status":"publish","type":"post","link":"https:\/\/tobykskgd.life\/index.php\/22\/","title":{"rendered":"\u674e\u5b8f\u6bc5\u673a\u5668\u5b66\u4e60\u8bfe\u7a0b\u7b14\u8bb0EP14"},"content":{"rendered":"\n<p>\u3010HW4\u3011Self attention0.1\u674e\u5b8f\u6bc52021\/2022\u6625\u673a\u5668\u5b66\u4e60\u8bfe\u7a0b\u7b14\u8bb0EP14(P45-P47)<\/p>\n\n\n\n<figure class=\"wp-block-image size-full\"><div class='fancybox-wrapper lazyload-container-unload' data-fancybox='post-images' href='https:\/\/tobykskgd.life\/wp-content\/uploads\/2024\/02\/\u5c4f\u5e55\u622a\u56fe-2024-02-05-213355.png'><img class=\"lazyload lazyload-style-1\" src=\"data:image\/svg+xml;base64,PCEtLUFyZ29uTG9hZGluZy0tPgo8c3ZnIHdpZHRoPSIxIiBoZWlnaHQ9IjEiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyIgc3Ryb2tlPSIjZmZmZmZmMDAiPjxnPjwvZz4KPC9zdmc+\"  loading=\"lazy\" decoding=\"async\" width=\"432\" height=\"218\" data-original=\"https:\/\/tobykskgd.life\/wp-content\/uploads\/2024\/02\/\u5c4f\u5e55\u622a\u56fe-2024-02-05-213355.png\" src=\"data:image\/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAADsQAAA7EAZUrDhsAAAANSURBVBhXYzh8+PB\/AAffA0nNPuCLAAAAAElFTkSuQmCC\" alt=\"\" class=\"wp-image-37\"  sizes=\"auto, (max-width: 432px) 100vw, 432px\" \/><\/div><\/figure>\n\n\n\n<p>\u4ece\u4eca\u5929\u5f00\u59cb\u6211\u5c06\u5b66\u4e60\u674e\u5b8f\u6bc5\u6559\u6388\u7684\u673a\u5668\u5b66\u4e60\u89c6\u9891\uff0c\u4e0b\u9762\u662f\u8bfe\u7a0b\u7684\u8fde\u63a5<a href=\"https:\/\/www.bilibili.com\/video\/BV1Wv411h7kN\/?spm_id_from=333.337.search-card.all.click&amp;vd_source=fa9de75b9e5251495ee15fc767cb5892\">(\u5f3a\u63a8)\u674e\u5b8f\u6bc52021\/2022\u6625\u673a\u5668\u5b66\u4e60\u8bfe\u7a0b_\u54d4\u54e9\u54d4\u54e9_bilibili<\/a>\u3002\u4e00\u5171\u6709155\u4e2a\u89c6\u9891\uff0c\u4e89\u53d6\u90fd\u5b66\u4e60\u5b8c\u6210\u5427\u3002<\/p>\n\n\n\n<p>\u90a3\u4e48\u9996\u5148\u8fd9\u95e8\u8bfe\u7a0b\u9700\u8981\u6709\u4e00\u5b9a\u7684\u4ee3\u7801\u57fa\u7840\uff0c\u7b80\u5355\u5b66\u4e60\u4e00\u4e0bPython\u7684\u57fa\u672c\u7528\u6cd5\uff0c\u8fd8\u6709\u91cc\u9762\u7684NumPy\u5e93\u7b49\u7b49\u7684\u57fa\u672c\u77e5\u8bc6\u3002\u518d\u5c31\u662f\u6570\u5b66\u65b9\u9762\u7684\u57fa\u7840\u5566\uff0c\u5fae\u79ef\u5206\u3001\u7ebf\u6027\u4ee3\u6570\u548c\u6982\u7387\u8bba\u7684\u57fa\u7840\u90fd\u662f\u542c\u61c2\u8fd9\u95e8\u8bfe\u5fc5\u987b\u7684\u3002<\/p>\n\n\n\n<hr class=\"wp-block-separator 
has-alpha-channel-opacity\"\/>\n\n\n\n<p>\u672c\u6b21\u4f5c\u4e1a\u6709\u4e00\u4e2a\u5c0f\u7684\u63d0\u793a\u5427\uff0c\u5c31\u662f\u4e0d\u8981\u5c1d\u8bd5\u8fd0\u884c2021\u6216\u80052022\u7684\u4ee3\u7801\u4e86\uff0c\u90a3\u4e2a\u91cc\u9762\u7684\u6570\u636e\u96c6\u7684\u6e90\u5730\u5740\u5df2\u7ecf404\uff0c\u672c\u6765\u6211\u5728\u7f51\u7edc\u4e0a\u627e\u5230\u4e86\u8fd9\u4e00\u8bfe\u7684\u6570\u636e\u96c6\u60f3\u5c1d\u8bd5\u653e\u5230colab\u4e0a\u9762\uff0c\u4f46\u662f\u56e0\u4e3a\u7f51\u7edc\u6ce2\u52a8\u539f\u56e0\u4e00\u76f4\u4e0a\u4f20\u5931\u8d25\uff0c\u6700\u540e\u53d1\u73b02023\u6700\u65b0\u7684\u4f5c\u4e1a\u91cc\u7684\u4ee3\u7801\u7684\u6570\u636e\u96c6\u662f\u53ef\u4ee5\u4e0b\u8f7d\u7684\uff0c\u6240\u4ee5\u5c31\u4e0d\u7528\u9ebb\u70e6\u4e86\uff0c\u76f4\u63a5\u8fd0\u884c2023\u7684\u4f5c\u4e1a\u5c31\u884c\u4e86\u3002\u90a3\u672c\u6b21\u4f5c\u4e1a\u7684\u6570\u636e\u96c6\u76f8\u6bd4\u4ee5\u5f80\u5927\u4e86\u4e0d\u5c11\uff0c\u6211\u81ea\u5df1\u6ca1\u600e\u4e48\u6539\uff0c\u7528\u4e86\u5dee\u4e0d\u591a\u5feb5\u4e2a\u5c0f\u65f6\u7684\u65f6\u95f4\uff0c\u5982\u679c\u60f3\u8fc7\u66f4\u597d\u7684baseline\u4f30\u8ba1\u8981\u6570\u500d\u7684\u65f6\u95f4\u3002<\/p>\n\n\n\n<hr class=\"wp-block-separator has-alpha-channel-opacity is-style-dots\"\/>\n\n\n\n<p>\u4e0b\u8f7d\u6570\u636e\u96c6<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>!wget https:\/\/github.com\/googly-mingto\/ML2023HW4\/releases\/download\/data\/Dataset.tar.gz.partaa\n!wget https:\/\/github.com\/googly-mingto\/ML2023HW4\/releases\/download\/data\/Dataset.tar.gz.partab\n!wget https:\/\/github.com\/googly-mingto\/ML2023HW4\/releases\/download\/data\/Dataset.tar.gz.partac\n!wget https:\/\/github.com\/googly-mingto\/ML2023HW4\/releases\/download\/data\/Dataset.tar.gz.partad\n\n!cat Dataset.tar.gz.part* &gt; Dataset.tar.gz\n!rm Dataset.tar.gz.partaa\n!rm Dataset.tar.gz.partab\n!rm Dataset.tar.gz.partac\n!rm Dataset.tar.gz.partad\n# unzip the file\n!tar zxf Dataset.tar.gz\n!rm Dataset.tar.gz<\/code><\/pre>\n\n\n\n<p>&#8211;2024-04-06 07:44:39&#8211; <a href=\"https:\/\/github.com\/googly-mingto\/ML2023HW4\/releases\/download\/data\/Dataset.tar.gz.partaa\" target=\"_blank\" rel=\"noreferrer noopener\">https:\/\/github.com\/googly-mingto\/ML2023HW4\/releases\/download\/data\/Dataset.tar.gz.partaa<\/a> Resolving github.com (github.com)&#8230; 140.82.114.3 Connecting to github.com (github.com)|140.82.114.3|:443&#8230; connected. 
noopener\">https:\/\/objects.githubusercontent.com\/github-production-release-asset-2e65be\/606989982\/0ee11da6-8c96-4463-b084-cea8f95d26e9?X-Amz-Algorithm=AWS4-HMAC-SHA256&amp;X-Amz-Credential=AKIAVCODYLSA53PQK4ZA%2F20240406%2Fus-east-1%2Fs3%2Faws4_request&amp;X-Amz-Date=20240406T074533Z&amp;X-Amz-Expires=300&amp;X-Amz-Signature=bedbbb164433c7d64178e9133717d7ae79ff24ff77722a12027b04b44d2612ce&amp;X-Amz-SignedHeaders=host&amp;actor_id=0&amp;key_id=0&amp;repo_id=606989982&amp;response-content-disposition=attachment%3B%20filename%3DDataset.tar.gz.partad&amp;response-content-type=application%2Foctet-stream<\/a> Resolving objects.githubusercontent.com (objects.githubusercontent.com)&#8230; 185.199.108.133, 185.199.111.133, 185.199.109.133, &#8230; Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443&#8230; connected. HTTP request sent, awaiting response&#8230; 200 OK Length: 1560784336 (1.5G) [application\/octet-stream] Saving to: \u2018Dataset.tar.gz.partad\u2019 Dataset.tar.gz.part 100%[===================&gt;] 1.45G 69.6MB\/s in 22s 2024-04-06 07:45:57 (68.7 MB\/s) &#8211; \u2018Dataset.tar.gz.partad\u2019 saved [1560784336\/1560784336] tar: Ignoring unknown extended header keyword &#8216;LIBARCHIVE.xattr.com.apple.macl&#8217;<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>!tar zxf Dataset.tar.gz<\/code><\/pre>\n\n\n\n<p>tar (child): Dataset.tar.gz: Cannot open: No such file or directory tar (child): Error is not recoverable: exiting now tar: Child returned status 2 tar: Error is not recoverable: exiting now<\/p>\n\n\n\n<p>\u56fa\u5b9a\u968f\u673a\u79cd\u5b50<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import numpy as np\nimport torch\nimport random\n\ndef set_seed(seed):\n    np.random.seed(seed)\n    random.seed(seed)\n    torch.manual_seed(seed)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed(seed)\n        torch.cuda.manual_seed_all(seed)\n    torch.backends.cudnn.benchmark = False\n    torch.backends.cudnn.deterministic = True\n\nset_seed(87)<\/code><\/pre>\n\n\n\n<p>Dataset<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import os\nimport json\nimport torch\nimport random\nfrom pathlib import Path\nfrom torch.utils.data import Dataset\nfrom torch.nn.utils.rnn import pad_sequence\n\n\nclass myDataset(Dataset):\n\tdef __init__(self, data_dir, segment_len=128):\n\t\tself.data_dir = data_dir\n\t\tself.segment_len = segment_len\n\n\t\t# Load the mapping from speaker neme to their corresponding id.\n\t\tmapping_path = Path(data_dir) \/ \"mapping.json\"\n\t\tmapping = json.load(mapping_path.open())\n\t\tself.speaker2id = mapping&#91;\"speaker2id\"]\n\n\t\t# Load metadata of training data.\n\t\tmetadata_path = Path(data_dir) \/ \"metadata.json\"\n\t\tmetadata = json.load(open(metadata_path))&#91;\"speakers\"]\n\n\t\t# Get the total number of speaker.\n\t\tself.speaker_num = len(metadata.keys())\n\t\tself.data = &#91;]\n\t\tfor speaker in metadata.keys():\n\t\t\tfor utterances in metadata&#91;speaker]:\n\t\t\t\tself.data.append(&#91;utterances&#91;\"feature_path\"], self.speaker2id&#91;speaker]])\n\n\tdef __len__(self):\n\t\t\treturn len(self.data)\n\n\tdef __getitem__(self, index):\n\t\tfeat_path, speaker = self.data&#91;index]\n\t\t# Load preprocessed mel-spectrogram.\n\t\tmel = torch.load(os.path.join(self.data_dir, feat_path))\n\n\t\t# Segmemt mel-spectrogram into \"segment_len\" frames.\n\t\tif len(mel) &gt; self.segment_len:\n\t\t\t# Randomly get the starting point of the segment.\n\t\t\tstart = random.randint(0, len(mel) - 
\n\n\n\n<p>Dataloader<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nfrom torch.utils.data import DataLoader, random_split\nfrom torch.nn.utils.rnn import pad_sequence\n\n\ndef collate_batch(batch):\n\t\"\"\"Collate a batch of data.\"\"\"\n\t# Process features within a batch.\n\tmel, speaker = zip(*batch)\n\t# Because we train the model batch by batch, we need to pad the features in the same batch to make their lengths the same.\n\tmel = pad_sequence(mel, batch_first=True, padding_value=-20)    # pad log 10^(-20) which is a very small value.\n\t# mel: (batch size, length, 40)\n\treturn mel, torch.FloatTensor(speaker).long()\n\n\ndef get_dataloader(data_dir, batch_size, n_workers):\n\t\"\"\"Generate dataloader\"\"\"\n\tdataset = myDataset(data_dir)\n\tspeaker_num = dataset.get_speaker_number()\n\t# Split dataset into training dataset and validation dataset\n\ttrainlen = int(0.9 * len(dataset))\n\tlengths = &#91;trainlen, len(dataset) - trainlen]\n\ttrainset, validset = random_split(dataset, lengths)\n\n\ttrain_loader = DataLoader(\n\t\ttrainset,\n\t\tbatch_size=batch_size,\n\t\tshuffle=True,\n\t\tdrop_last=True,\n\t\tnum_workers=n_workers,\n\t\tpin_memory=True,\n\t\tcollate_fn=collate_batch,\n\t)\n\tvalid_loader = DataLoader(\n\t\tvalidset,\n\t\tbatch_size=batch_size,\n\t\tnum_workers=n_workers,\n\t\tdrop_last=True,\n\t\tpin_memory=True,\n\t\tcollate_fn=collate_batch,\n\t)\n\n\treturn train_loader, valid_loader, speaker_num<\/code><\/pre>\n\n\n\n<p>Model<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Classifier(nn.Module):\n\tdef __init__(self, d_model=80, n_spks=600, dropout=0.1):\n\t\tsuper().__init__()\n\t\t# Project the dimension of features from that of input into d_model.\n\t\tself.prenet = nn.Linear(40, d_model)\n\t\t# TODO:\n\t\t#   Change Transformer to Conformer.\n\t\t#   https:\/\/arxiv.org\/abs\/2005.08100\n\t\tself.encoder_layer = nn.TransformerEncoderLayer(\n\t\t\td_model=d_model, dim_feedforward=256, nhead=2\n\t\t)\n\t\t# self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=2)\n\n\t\t# Project the dimension of features from d_model into the number of speakers.\n\t\tself.pred_layer = nn.Sequential(\n\t\t\tnn.Linear(d_model, d_model),\n\t\t\tnn.Sigmoid(),\n\t\t\tnn.Linear(d_model, n_spks),\n\t\t)\n\n\tdef forward(self, mels):\n\t\t\"\"\"\n\t\targs:\n\t\t\tmels: (batch size, length, 40)\n\t\treturn:\n\t\t\tout: (batch size, n_spks)\n\t\t\"\"\"\n\t\t# out: (batch size, length, d_model)\n\t\tout = self.prenet(mels)\n\t\t# out: (length, batch size, d_model)\n\t\tout = out.permute(1, 0, 2)\n\t\t# The encoder layer expects features in the shape of (length, batch size, d_model).\n\t\tout = self.encoder_layer(out)\n\t\t# out: (batch size, length, d_model)\n\t\tout = out.transpose(0, 1)\n\t\t# mean pooling\n\t\tstats = out.mean(dim=1)\n\n\t\t# out: (batch, n_spks)\n\t\tout = self.pred_layer(stats)\n\t\treturn out<\/code><\/pre>
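\n\n\n\n<p>The sample model runs a single TransformerEncoderLayer and averages the frame outputs (mean pooling). The TODO points toward Conformer for the stronger baselines; as a smaller first step, here is a hedged sketch of two common tweaks: stacking several encoder layers with nn.TransformerEncoder (the commented-out line above) and swapping mean pooling for a learned self-attentive pooling. DeeperClassifier and SelfAttentivePooling are my own illustrative names, not part of the assignment code.<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\nimport torch.nn as nn\n\n\nclass SelfAttentivePooling(nn.Module):\n\t# Learn one weight per frame and take a weighted average instead of a plain mean.\n\tdef __init__(self, d_model):\n\t\tsuper().__init__()\n\t\tself.scorer = nn.Linear(d_model, 1)\n\n\tdef forward(self, x):\n\t\t# x: (batch size, length, d_model)\n\t\tweights = torch.softmax(self.scorer(x), dim=1)  # (batch size, length, 1)\n\t\treturn (weights * x).sum(dim=1)                 # (batch size, d_model)\n\n\nclass DeeperClassifier(nn.Module):\n\t# Hypothetical variant of Classifier: stacked encoder + attentive pooling.\n\tdef __init__(self, d_model=80, n_spks=600, num_layers=3):\n\t\tsuper().__init__()\n\t\tself.prenet = nn.Linear(40, d_model)\n\t\tencoder_layer = nn.TransformerEncoderLayer(\n\t\t\td_model=d_model, dim_feedforward=256, nhead=2\n\t\t)\n\t\tself.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n\t\tself.pooling = SelfAttentivePooling(d_model)\n\t\tself.pred_layer = nn.Linear(d_model, n_spks)\n\n\tdef forward(self, mels):\n\t\tout = self.prenet(mels)        # (batch size, length, d_model)\n\t\tout = out.permute(1, 0, 2)     # (length, batch size, d_model)\n\t\tout = self.encoder(out)\n\t\tout = out.transpose(0, 1)      # (batch size, length, d_model)\n\t\tstats = self.pooling(out)      # (batch size, d_model)\n\t\treturn self.pred_layer(stats)  # (batch size, n_spks)<\/code><\/pre>\n\n\n\n<p>Since it takes the same n_spks argument and the same (batch size, length, 40) input, it could be dropped into the training cell in place of Classifier without touching the rest of the pipeline.<\/p>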
\n\n\n\n<p>Learning rate schedule<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import math\n\nimport torch\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\n\n\ndef get_cosine_schedule_with_warmup(\n\toptimizer: Optimizer,\n\tnum_warmup_steps: int,\n\tnum_training_steps: int,\n\tnum_cycles: float = 0.5,\n\tlast_epoch: int = -1,\n):\n\t\"\"\"\n\tCreate a schedule with a learning rate that decreases following the values of the cosine function between the\n\tinitial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the\n\tinitial lr set in the optimizer.\n\n\tArgs:\n\t\toptimizer (:class:`~torch.optim.Optimizer`):\n\t\tThe optimizer for which to schedule the learning rate.\n\t\tnum_warmup_steps (:obj:`int`):\n\t\tThe number of steps for the warmup phase.\n\t\tnum_training_steps (:obj:`int`):\n\t\tThe total number of training steps.\n\t\tnum_cycles (:obj:`float`, `optional`, defaults to 0.5):\n\t\tThe number of waves in the cosine schedule (the default is to just decrease from the max value to 0\n\t\tfollowing a half-cosine).\n\t\tlast_epoch (:obj:`int`, `optional`, defaults to -1):\n\t\tThe index of the last epoch when resuming training.\n\n\tReturn:\n\t\t:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.\n\t\"\"\"\n\tdef lr_lambda(current_step):\n\t\t# Warmup\n\t\tif current_step &lt; num_warmup_steps:\n\t\t\treturn float(current_step) \/ float(max(1, num_warmup_steps))\n\t\t# Cosine decay\n\t\tprogress = float(current_step - num_warmup_steps) \/ float(\n\t\t\tmax(1, num_training_steps - num_warmup_steps)\n\t\t)\n\t\treturn max(\n\t\t\t0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))\n\t\t)\n\n\treturn LambdaLR(optimizer, lr_lambda, last_epoch)<\/code><\/pre>\n\n\n\n<p>Model function<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import torch\n\n\ndef model_fn(batch, model, criterion, device):\n\t\"\"\"Forward a batch through the model.\"\"\"\n\n\tmels, labels = batch\n\tmels = mels.to(device)\n\tlabels = labels.to(device)\n\n\touts = model(mels)\n\n\tloss = criterion(outs, labels)\n\n\t# Get the speaker id with highest probability.\n\tpreds = outs.argmax(1)\n\t# Compute accuracy.\n\taccuracy = torch.mean((preds == labels).float())\n\n\treturn loss, accuracy<\/code><\/pre>\n\n\n\n<p>Validation<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from tqdm import tqdm\nimport torch\n\n\ndef valid(dataloader, model, criterion, device):\n\t\"\"\"Validate on validation set.\"\"\"\n\n\tmodel.eval()\n\trunning_loss = 0.0\n\trunning_accuracy = 0.0\n\tpbar = tqdm(total=len(dataloader.dataset), ncols=0, desc=\"Valid\", unit=\" uttr\")\n\n\tfor i, batch in enumerate(dataloader):\n\t\twith torch.no_grad():\n\t\t\tloss, accuracy = model_fn(batch, model, criterion, device)\n\t\t\trunning_loss += loss.item()\n\t\t\trunning_accuracy += accuracy.item()\n\n\t\tpbar.update(dataloader.batch_size)\n\t\tpbar.set_postfix(\n\t\t\tloss=f\"{running_loss \/ (i+1):.2f}\",\n\t\t\taccuracy=f\"{running_accuracy \/ (i+1):.2f}\",\n\t\t)\n\n\tpbar.close()\n\tmodel.train()\n\n\treturn running_accuracy \/ len(dataloader)<\/code><\/pre>\n\n\n\n<p>Main function<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from tqdm import tqdm\n\nimport torch\nimport torch.nn as nn\nfrom torch.optim import AdamW\nfrom torch.utils.data import DataLoader, random_split\n\n\ndef parse_args():\n\t\"\"\"arguments\"\"\"\n\tconfig = {\n\t\t\"data_dir\": \".\/Dataset\",\n\t\t\"save_path\": \"model.ckpt\",\n\t\t\"batch_size\": 64,\n\t\t\"n_workers\": 8,\n\t\t\"valid_steps\": 2000,\n\t\t\"warmup_steps\": 1000,\n\t\t\"save_steps\": 10000,\n\t\t\"total_steps\": 
70000,\n\t}\n\n\treturn config\n\n\ndef main(\n\tdata_dir,\n\tsave_path,\n\tbatch_size,\n\tn_workers,\n\tvalid_steps,\n\twarmup_steps,\n\ttotal_steps,\n\tsave_steps,\n):\n\t\"\"\"Main function.\"\"\"\n\tdevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\tprint(f\"&#91;Info]: Use {device} now!\")\n\n\ttrain_loader, valid_loader, speaker_num = get_dataloader(data_dir, batch_size, n_workers)\n\ttrain_iterator = iter(train_loader)\n\tprint(f\"&#91;Info]: Finish loading data!\",flush = True)\n\n\tmodel = Classifier(n_spks=speaker_num).to(device)\n\tcriterion = nn.CrossEntropyLoss()\n\toptimizer = AdamW(model.parameters(), lr=1e-3)\n\tscheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)\n\tprint(f\"&#91;Info]: Finish creating model!\",flush = True)\n\n\tbest_accuracy = -1.0\n\tbest_state_dict = None\n\n\tpbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n\n\tfor step in range(total_steps):\n\t\t# Get data\n\t\ttry:\n\t\t\tbatch = next(train_iterator)\n\t\texcept StopIteration:\n\t\t\ttrain_iterator = iter(train_loader)\n\t\t\tbatch = next(train_iterator)\n\n\t\tloss, accuracy = model_fn(batch, model, criterion, device)\n\t\tbatch_loss = loss.item()\n\t\tbatch_accuracy = accuracy.item()\n\n\t\t# Updata model\n\t\tloss.backward()\n\t\toptimizer.step()\n\t\tscheduler.step()\n\t\toptimizer.zero_grad()\n\n\t\t# Log\n\t\tpbar.update()\n\t\tpbar.set_postfix(\n\t\t\tloss=f\"{batch_loss:.2f}\",\n\t\t\taccuracy=f\"{batch_accuracy:.2f}\",\n\t\t\tstep=step + 1,\n\t\t)\n\n\t\t# Do validation\n\t\tif (step + 1) % valid_steps == 0:\n\t\t\tpbar.close()\n\n\t\t\tvalid_accuracy = valid(valid_loader, model, criterion, device)\n\n\t\t\t# keep the best model\n\t\t\tif valid_accuracy &gt; best_accuracy:\n\t\t\t\tbest_accuracy = valid_accuracy\n\t\t\t\tbest_state_dict = model.state_dict()\n\n\t\t\tpbar = tqdm(total=valid_steps, ncols=0, desc=\"Train\", unit=\" step\")\n\n\t\t# Save the best model so far.\n\t\tif (step + 1) % save_steps == 0 and best_state_dict is not None:\n\t\t\ttorch.save(best_state_dict, save_path)\n\t\t\tpbar.write(f\"Step {step + 1}, best model saved. (accuracy={best_accuracy:.4f})\")\n\n\tpbar.close()\n\n\nif __name__ == \"__main__\":\n\tmain(**parse_args())<\/code><\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">[Info]: Use cpu now!\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Train:   0% 0\/2000 [24:09&lt;?, ? step\/s]\nTrain:   0% 0\/2000 [22:12&lt;?, ? 
step\/s]\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">[Info]: Finish loading data!\n[Info]: Finish creating model!\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Train: 100% 2000\/2000 [08:21&lt;00:00,  3.99 step\/s, accuracy=0.05, loss=4.89, step=2000]\nValid:  99% 5632\/5667 [00:16&lt;00:00, 345.07 uttr\/s, accuracy=0.06, loss=4.94]\nTrain: 100% 2000\/2000 [08:18&lt;00:00,  4.02 step\/s, accuracy=0.16, loss=4.23, step=4000]\nValid:  99% 5632\/5667 [00:13&lt;00:00, 429.97 uttr\/s, accuracy=0.15, loss=4.22]\nTrain: 100% 2000\/2000 [08:07&lt;00:00,  4.10 step\/s, accuracy=0.20, loss=3.84, step=6000]\nValid:  99% 5632\/5667 [00:10&lt;00:00, 523.66 uttr\/s, accuracy=0.20, loss=3.82]\nTrain: 100% 2000\/2000 [08:20&lt;00:00,  4.00 step\/s, accuracy=0.34, loss=3.35, step=8000]\nValid:  99% 5632\/5667 [00:11&lt;00:00, 506.66 uttr\/s, accuracy=0.26, loss=3.50]\nTrain: 100% 2000\/2000 [08:16&lt;00:00,  4.03 step\/s, accuracy=0.36, loss=3.18, step=1e+4]\nValid:  99% 5632\/5667 [00:07&lt;00:00, 732.66 uttr\/s, accuracy=0.29, loss=3.26] \n\n\nTrain:   0% 0\/2000 [00:00&lt;?, ? step\/s]\n\nTrain:   0% 0\/2000 [1:04:43&lt;?, ? step\/s]<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Step 10000, best model saved. (accuracy=0.2919)\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Train: 100% 2000\/2000 [08:14&lt;00:00,  4.05 step\/s, accuracy=0.41, loss=3.03, step=12000]\nValid:  99% 5632\/5667 [00:10&lt;00:00, 547.19 uttr\/s, accuracy=0.32, loss=3.15]\nTrain: 100% 2000\/2000 [08:03&lt;00:00,  4.13 step\/s, accuracy=0.41, loss=2.50, step=14000]\nValid:  99% 5632\/5667 [00:09&lt;00:00, 576.39 uttr\/s, accuracy=0.35, loss=2.97]\nTrain: 100% 2000\/2000 [08:18&lt;00:00,  4.01 step\/s, accuracy=0.44, loss=2.55, step=16000]\nValid:  99% 5632\/5667 [00:13&lt;00:00, 412.67 uttr\/s, accuracy=0.38, loss=2.84]\nTrain: 100% 2000\/2000 [08:11&lt;00:00,  4.07 step\/s, accuracy=0.45, loss=2.55, step=18000]\nValid:  99% 5632\/5667 [00:14&lt;00:00, 399.51 uttr\/s, accuracy=0.40, loss=2.73]\nTrain: 100% 2000\/2000 [08:21&lt;00:00,  3.99 step\/s, accuracy=0.36, loss=2.65, step=2e+4]\nValid:  99% 5632\/5667 [00:09&lt;00:00, 591.63 uttr\/s, accuracy=0.42, loss=2.62]\n\n\nTrain:   0% 0\/2000 [00:00&lt;?, ? step\/s]\n\nTrain:   0% 0\/2000 [1:46:50&lt;?, ? step\/s]<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Step 20000, best model saved. (accuracy=0.4213)\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Train: 100% 2000\/2000 [08:17&lt;00:00,  4.02 step\/s, accuracy=0.45, loss=2.40, step=22000]\nValid:  99% 5632\/5667 [00:10&lt;00:00, 546.98 uttr\/s, accuracy=0.45, loss=2.52]\nTrain: 100% 2000\/2000 [08:14&lt;00:00,  4.04 step\/s, accuracy=0.48, loss=2.24, step=24000]\nValid:  99% 5632\/5667 [00:12&lt;00:00, 464.67 uttr\/s, accuracy=0.45, loss=2.51]\nTrain: 100% 2000\/2000 [08:10&lt;00:00,  4.08 step\/s, accuracy=0.50, loss=2.19, step=26000]\nValid:  99% 5632\/5667 [00:09&lt;00:00, 612.79 uttr\/s, accuracy=0.46, loss=2.42] \nTrain: 100% 2000\/2000 [08:19&lt;00:00,  4.00 step\/s, accuracy=0.52, loss=2.19, step=28000]\nValid:  99% 5632\/5667 [00:10&lt;00:00, 543.84 uttr\/s, accuracy=0.47, loss=2.36] \nTrain: 100% 2000\/2000 [08:24&lt;00:00,  3.96 step\/s, accuracy=0.48, loss=2.03, step=3e+4]\nValid:  99% 5632\/5667 [00:14&lt;00:00, 380.97 uttr\/s, accuracy=0.49, loss=2.33]\n\n\nTrain:   0% 0\/2000 [00:00&lt;?, ? step\/s]\n\nTrain:   0% 0\/2000 [2:29:14&lt;?, ? step\/s]<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Step 30000, best model saved. 
(accuracy=0.4854)\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Train: 100% 2000\/2000 [08:21&lt;00:00,  3.99 step\/s, accuracy=0.44, loss=2.43, step=32000]\nValid:  99% 5632\/5667 [00:07&lt;00:00, 704.44 uttr\/s, accuracy=0.50, loss=2.26] \nTrain: 100% 2000\/2000 [08:17&lt;00:00,  4.02 step\/s, accuracy=0.66, loss=1.72, step=34000]\nValid:  99% 5632\/5667 [00:09&lt;00:00, 620.62 uttr\/s, accuracy=0.50, loss=2.24] \nTrain: 100% 2000\/2000 [08:20&lt;00:00,  4.00 step\/s, accuracy=0.61, loss=1.80, step=36000]\nValid:  99% 5632\/5667 [00:09&lt;00:00, 600.96 uttr\/s, accuracy=0.51, loss=2.19]\nTrain: 100% 2000\/2000 [08:13&lt;00:00,  4.06 step\/s, accuracy=0.64, loss=1.74, step=38000]\nValid:  99% 5632\/5667 [00:12&lt;00:00, 444.12 uttr\/s, accuracy=0.53, loss=2.13]\nTrain: 100% 2000\/2000 [08:20&lt;00:00,  3.99 step\/s, accuracy=0.70, loss=1.49, step=4e+4]\nValid:  99% 5632\/5667 [00:10&lt;00:00, 539.81 uttr\/s, accuracy=0.53, loss=2.09] \n\n\nTrain:   0% 0\/2000 [00:00&lt;?, ? step\/s]\n\nTrain:   0% 0\/2000 [3:11:36&lt;?, ? step\/s]<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Step 40000, best model saved. (accuracy=0.5300)\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Train: 100% 2000\/2000 [08:15&lt;00:00,  4.04 step\/s, accuracy=0.62, loss=1.76, step=42000]\nValid:  99% 5632\/5667 [00:13&lt;00:00, 417.85 uttr\/s, accuracy=0.53, loss=2.06]\nTrain: 100% 2000\/2000 [08:18&lt;00:00,  4.02 step\/s, accuracy=0.56, loss=2.04, step=44000]\nValid:  99% 5632\/5667 [00:11&lt;00:00, 495.57 uttr\/s, accuracy=0.55, loss=2.04] \nTrain: 100% 2000\/2000 [08:13&lt;00:00,  4.05 step\/s, accuracy=0.53, loss=1.97, step=46000]\nValid:  99% 5632\/5667 [00:08&lt;00:00, 670.28 uttr\/s, accuracy=0.55, loss=1.98]\nTrain: 100% 2000\/2000 [08:03&lt;00:00,  4.14 step\/s, accuracy=0.56, loss=1.95, step=48000]\nValid:  99% 5632\/5667 [00:16&lt;00:00, 346.02 uttr\/s, accuracy=0.56, loss=1.99]\nTrain: 100% 2000\/2000 [08:08&lt;00:00,  4.09 step\/s, accuracy=0.59, loss=1.92, step=5e+4]\nValid:  99% 5632\/5667 [00:10&lt;00:00, 554.63 uttr\/s, accuracy=0.56, loss=2.00] \n\n\nTrain:   0% 0\/2000 [00:00&lt;?, ? step\/s]\n\nTrain:   0% 0\/2000 [3:53:35&lt;?, ? step\/s]<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Step 50000, best model saved. (accuracy=0.5584)\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Train: 100% 2000\/2000 [08:16&lt;00:00,  4.03 step\/s, accuracy=0.58, loss=1.55, step=52000]\nValid:  99% 5632\/5667 [00:11&lt;00:00, 501.86 uttr\/s, accuracy=0.57, loss=1.94]\nTrain: 100% 2000\/2000 [08:17&lt;00:00,  4.02 step\/s, accuracy=0.48, loss=1.97, step=54000]\nValid:  99% 5632\/5667 [00:17&lt;00:00, 327.04 uttr\/s, accuracy=0.57, loss=1.93]\nTrain: 100% 2000\/2000 [08:26&lt;00:00,  3.95 step\/s, accuracy=0.81, loss=1.05, step=56000]\nValid:  99% 5632\/5667 [00:07&lt;00:00, 712.35 uttr\/s, accuracy=0.57, loss=1.93] \nTrain: 100% 2000\/2000 [08:18&lt;00:00,  4.01 step\/s, accuracy=0.61, loss=1.72, step=58000]\nValid:  99% 5632\/5667 [00:09&lt;00:00, 567.86 uttr\/s, accuracy=0.57, loss=1.94]\nTrain: 100% 2000\/2000 [08:19&lt;00:00,  4.00 step\/s, accuracy=0.73, loss=1.19, step=6e+4]\nValid:  99% 5632\/5667 [00:17&lt;00:00, 313.22 uttr\/s, accuracy=0.58, loss=1.93]\n\n\nTrain:   0% 0\/2000 [00:00&lt;?, ? step\/s]\n\nTrain:   0% 0\/2000 [4:36:19&lt;?, ? step\/s]<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Step 60000, best model saved. 
(accuracy=0.5785)\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Train: 100% 2000\/2000 [08:08&lt;00:00,  4.09 step\/s, accuracy=0.67, loss=1.34, step=62000]\nValid:  99% 5632\/5667 [00:10&lt;00:00, 517.80 uttr\/s, accuracy=0.59, loss=1.88] \nTrain: 100% 2000\/2000 [08:17&lt;00:00,  4.02 step\/s, accuracy=0.61, loss=1.61, step=64000]\nValid:  99% 5632\/5667 [00:09&lt;00:00, 624.30 uttr\/s, accuracy=0.58, loss=1.91] \nTrain: 100% 2000\/2000 [08:17&lt;00:00,  4.02 step\/s, accuracy=0.61, loss=1.64, step=66000]\nValid:  99% 5632\/5667 [00:11&lt;00:00, 511.24 uttr\/s, accuracy=0.57, loss=1.91]\nTrain: 100% 2000\/2000 [08:21&lt;00:00,  3.99 step\/s, accuracy=0.72, loss=1.20, step=68000]\nValid:  99% 5632\/5667 [00:08&lt;00:00, 684.44 uttr\/s, accuracy=0.58, loss=1.89]\nTrain: 100% 2000\/2000 [08:17&lt;00:00,  4.02 step\/s, accuracy=0.64, loss=1.54, step=7e+4]\nValid:  99% 5632\/5667 [00:08&lt;00:00, 638.93 uttr\/s, accuracy=0.58, loss=1.89] \n\n\nTrain:   0% 0\/2000 [00:00&lt;?, ? step\/s]\n\nTrain:   0% 0\/2000 [00:00&lt;?, ? step\/s]\n<\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">Step 70000, best model saved. (accuracy=0.5881)<\/pre>\n\n\n\n<p>Inference<\/p>\n\n\n\n<p>Dataset for inference<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import os\nimport json\nimport torch\nfrom pathlib import Path\nfrom torch.utils.data import Dataset\n\n\nclass InferenceDataset(Dataset):\n\tdef __init__(self, data_dir):\n\t\ttestdata_path = Path(data_dir) \/ \"testdata.json\"\n\t\tmetadata = json.load(testdata_path.open())\n\t\tself.data_dir = data_dir\n\t\tself.data = metadata&#91;\"utterances\"]\n\n\tdef __len__(self):\n\t\treturn len(self.data)\n\n\tdef __getitem__(self, index):\n\t\tutterance = self.data&#91;index]\n\t\tfeat_path = utterance&#91;\"feature_path\"]\n\t\tmel = torch.load(os.path.join(self.data_dir, feat_path))\n\n\t\treturn feat_path, mel\n\n\ndef inference_collate_batch(batch):\n\t\"\"\"Collate a batch of data.\"\"\"\n\tfeat_paths, mels = zip(*batch)\n\n\treturn feat_paths, torch.stack(mels)<\/code><\/pre>
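\n\n\n\n<p>Note that the inference dataloader below uses batch_size=1, so torch.stack in inference_collate_batch never has to deal with utterances of different lengths. If you wanted larger inference batches, a padded collate in the spirit of the training one would be needed; here is a minimal sketch (my own addition, with a hypothetical name):<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>from torch.nn.utils.rnn import pad_sequence\n\n\ndef padded_inference_collate(batch):\n\t# Pad variable-length utterances with the same very small log-mel value used in training.\n\tfeat_paths, mels = zip(*batch)\n\tmels = pad_sequence(mels, batch_first=True, padding_value=-20)\n\treturn feat_paths, mels<\/code><\/pre>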
\n\n\n\n<p>Main function for inference<\/p>\n\n\n\n<pre class=\"wp-block-code\"><code>import json\nimport csv\nfrom pathlib import Path\nfrom tqdm.notebook import tqdm\n\nimport torch\nfrom torch.utils.data import DataLoader\n\ndef parse_args():\n\t\"\"\"arguments\"\"\"\n\tconfig = {\n\t\t\"data_dir\": \".\/Dataset\",\n\t\t\"model_path\": \".\/model.ckpt\",\n\t\t\"output_path\": \".\/output.csv\",\n\t}\n\n\treturn config\n\n\ndef main(\n\tdata_dir,\n\tmodel_path,\n\toutput_path,\n):\n\t\"\"\"Main function.\"\"\"\n\tdevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\tprint(f\"&#91;Info]: Use {device} now!\")\n\n\tmapping_path = Path(data_dir) \/ \"mapping.json\"\n\tmapping = json.load(mapping_path.open())\n\n\tdataset = InferenceDataset(data_dir)\n\tdataloader = DataLoader(\n\t\tdataset,\n\t\tbatch_size=1,\n\t\tshuffle=False,\n\t\tdrop_last=False,\n\t\tnum_workers=8,\n\t\tcollate_fn=inference_collate_batch,\n\t)\n\tprint(f\"&#91;Info]: Finish loading data!\",flush = True)\n\n\tspeaker_num = len(mapping&#91;\"id2speaker\"])\n\tmodel = Classifier(n_spks=speaker_num).to(device)\n\tmodel.load_state_dict(torch.load(model_path))\n\tmodel.eval()\n\tprint(f\"&#91;Info]: Finish creating model!\",flush = True)\n\n\tresults = &#91;&#91;\"Id\", \"Category\"]]\n\tfor feat_paths, mels in tqdm(dataloader):\n\t\twith torch.no_grad():\n\t\t\tmels = mels.to(device)\n\t\t\touts = model(mels)\n\t\t\tpreds = outs.argmax(1).cpu().numpy()\n\t\t\tfor feat_path, pred in zip(feat_paths, preds):\n\t\t\t\tresults.append(&#91;feat_path, mapping&#91;\"id2speaker\"]&#91;str(pred)]])\n\n\twith open(output_path, 'w', newline='') as csvfile:\n\t\twriter = csv.writer(csvfile)\n\t\twriter.writerows(results)\n\n\nif __name__ == \"__main__\":\n\tmain(**parse_args())<\/code><\/pre>\n\n\n\n<pre class=\"wp-block-preformatted\">[Info]: Use cpu now!\n[Info]: Finish loading data!\n[Info]: Finish creating model!\n<\/pre>\n\n\n\n<p>100% 8000\/8000 [02:10&lt;00:00, 53.94it\/s]<\/p>\n","protected":false},"excerpt":{"rendered":"<p>[HW4] Self attention 0.1 &#8211; Hung-yi Lee 2021\/2022 Spring Machine Learning Course Notes EP14 (P45-P47 [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":0,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[6],"tags":[15,3,7,9,8],"class_list":["post-547","post","type-post","status-publish","format-standard","hentry","category-lhyjqxxbj","tag-homework","tag-xxbj","tag-jjxx","tag-lhy","tag-deeplearning"],"_links":{"self":[{"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/posts\/547","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/comments?post=547"}],"version-history":[{"count":3,"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/posts\/547\/revisions"}],"predecessor-version":[{"id":1879,"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/posts\/547\/revisions\/1879"}],"wp:attachment":[{"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/media?parent=547"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/categories?post=547"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/tobykskgd.life\/index.php\/wp-json\/wp\/v2\/tags?post=547"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}