среда, 5 сентября 2018 г.

Конфигурируем Raspberry Pi3 в качестве среды разработки для ESP32

Я уже упоминал, что предпочитаю использовать Raspberry Pi в качестве среды разработки для своих "микроконтроллерных проектов". ESP32 - одна из популярных платформ, которая имеет на борту довольно мощный контроллер и Wi-Fi чип.

Обычно подготовка RPi заключается установке пакетов и стандартных репозиториев Raspbian, но, похоже в стандартных репозиториях соответствующих пакетов нет. К счастью, это не является большой проблемой, так как можно просто собрать всё недостающее из исходников.

Поскольку мне уже приходилось пару раз устанавливать необходимые для ESP32 инструменты в Raspberry Pi, мне снова и снова приходилось искать информацию на разных сайтах и компоновать всё воедино. Поэтому я решил собрать здесь пошаговую инструкцию, которая сработала для меня. Первым делом необходимо собрать toolchain -- набор программ для генерации прошивок из исходного кода (компилятор C/C++, ассемблер, линковщик и.т.д.). Инструкцию для Linux можно найти здесь, однако она может не сработать для RPi, так как по умолчанию для сборки будут задействованы все ядра ЦП, а для одновременной компиляции в несколько потоков у RPi может не хватить памяти, поэтому лучше воспользоваться немного модифицированной инструкцией отсюда, и собирать тулчейн в один поток. Итак:

Сборка занимает значительное время, поэтому если вы, как и я, используете RPi без мониора, подключаясь по SSH помощью  PuTTY или другого терминала, то предпочтительнее работать в среде screen. В этом случае процесс не остановится, если сетевое соединение между вашим компьютером и Raspberry PI оборвётся, или если вы закроете PuTTY (в противном случае процесс сборки убьётся как только разорвётся сессия SSH).

Итак, запускаем сессию screen

>screen

Обновляем систему и устанавливаем все необходимые зависимости

>sudo apt-get update

>sudo apt-get upgrade


>sudo apt-get install gawk gperf grep gettext automake bison flex texinfo help2man libtool libtool-bin git wget make libncurses-dev python3 python-serial python-dev python-pip

>sudo pip install pyserial

Подготавливаем директорию, в которой будем собирать toolchain:

>mkdir ~/esp
>cd ~/esp


Клонируем репозиторий с исходниками

>git clone -b xtensa-1.22.x https://github.com/espressif/crosstool-NG.git

Генерируем файлы для сборки

>cd crosstool-NG
>./bootstrap && ./configure --enable-local && make install
>./ct-ng xtensa-esp32-elf

И подправляем конфиг для сборки проекта

>nano ./.config

Следует найти строку CT_PARALLEL_JOBS=0 и заменить значение 0 на 1. Этим мы указываем, что при компиляции следует ограничиться одним потоком, а не использовать все доступные ядра. Тем самым мы значительно увеличиваем время сборки (на и без того небыстрой платформе), но при этом сборка не упадет из за нехватки памяти.
Инструкция рекомендует использовать редактор nano, но разумеется можно воспользоваться любым другим текстовым редактором, например, vim.

Запускаем сборку, предварительно запасшись терпением: у меня собиралось ночь и половину следующего дня:

>./ct-ng build build.1

Устанавливаем флаг executable на получившиеся бинарные файлы

>chmod -R u+w builds/xtensa-esp32-elf

Копируем полученные утилиты в директрию ~/esp . Следующая инструкция по созданию простейшего приложения для ESP32 написана с расчётом на то, что бинарники тулчейна лежат в директории ~/esp/xtensa-esp32-elf/bin. Кроме того, добавим путь к утилитам в переменную PATH.

>mv ~/esp/crosstool-NG/builds/xtensa-esp32-elf/ ~/esp
>export PATH="$PATH:$HOME/esp/xtensa-esp32-elf/bin"

Дальнейшее основано на инструкции с сайта espressif.com: Итак, теперь соберём прошивку, которая, как обычно, помигает светодиодом. Клонируем исходники API/библиотек/примеров для ESP32 с помощью git.

>cd ~/esp
>git clone --recursive https://github.com/espressif/esp-idf.git


Устанавливаем переменную окружения IDF_PATH=~/esp/esp-idf.

Для текущей сессии:

>export IDF_PATH=~/esp/esp-idf

Для будущих сессий следует добавить строчку "export IDF_PATH=~/esp/esp-idf" (без кавычек) в файл ~/.profile .

На всякий случай проверим, что все питоновские библиотеку установлены:

>sudo python -m pip install -r $IDF_PATH/requirements.txt

Копируем пример blink в директорию ~/esp.
>cp -r $IDF_PATH/examples/get-started/blink .
>cd ~/esp/blink
При подключении платы ESP32 к USB-разьёму Raspberry Pi в директории /dev должно появиться соответствующее устройство: в моём случае файл устройства был /dev/ttyAMA0. Следует использовать его в утилите прошивки. По идее должно быть достаточно выполнить команду

>make flash

Но у меня почему-то не сработало. Вместо этого пришлось отдельно собирать прект

>make

А затем прошивать
> ~/esp/esp-idf/components/esptool_py/esptool/esptool.py write_flash --flash_mode dio --flash_freq 40m --flash_size detect 0x10000 build/blink.bin


На моей версии платы ESP32 есть две кнопки: IO0 и EN. Для прошивки необходимо держать кнопку IO0 нажатой, и при этом однократно нажать кнопку EN (не отпуская IO0). Немного больше деталей тут.


пятница, 18 ноября 2016 г.

Исользование UART в микроконтроллере sam3x8e и Arduino Due

Интерфейс UART жизненно необходим при разработке устройств на основе микроконтроллеров. Даже если ваш проект не подразумевает подключение устройства к компьютеру (что проще всего сделать через интерфейс UART), этот интерфейс может сильно помочь в диагностике и отладке прошивок.

Микроконтроллер sam3x8e обладает одним портом UART и аж тремя портами USART. В этом посте мы не будем касаться особенностей USART, скажу лишь что USART является надмножеством UART,  так что будем считать, что в нашем распоряжении 4 порта UART.

Как и вся остальная периферия, UART/USART управляется регистрами, которые отображены в адресное пространство микроконтроллера. Управляющие регистры портов находятся по следующим адрксам:

ПортАдреса регистров
UART0x400E0800-0x400E0924
USART10x40098000-0x40098124
USART20x4009C000-0x4009C124
USART30x400A4000-0x400A4124

У каждого порта имеются такие регистры (R в колонке "доступ" обозначает доступность регистра для чтения, W - доступность регистра для записи).

Смещение Регистр UART Регистр USART Назначение Доступ
0x0000 UART_CR US_CR Control register W
0x0004 UART_MR US_MR Mode register RW
0x0008 UART_IER US_IER Interupt Enable register W
0x000C UART_IDR US_IDR Interupt Disable Register W
0x0010 UART_IMR US_IMR Interupt Mask Register R
0x0014 UART_SR US_SR Status Register R
0x0018 UART_RHR US_RHR Receive Holding Register R
0x001C UART_THR US_THR Transmit Holding Register W
0x0020 UART_BRGR US_BRGR Baud Rate Generator Register W
0x0100-0x0124 PDC Area PDC Area Registers controlling DMA


Для удобного доступа к этим ячейкам-регистрам в заголовочных файлах libsam определены такие структуры:


Видно, что функциаонально эквивалентные регистры USART и UART находятся по одинаковым смещениям, поэтому в функции, принимающие в качестве аргумента указатели на объект типа Uart можно передавать объекты типа Usart, и по идее в этом случае UART-подмножество функционала USART будет работать!

Обмен через Uart осуществляется примерно так: если передатчик готов принимать следующий байт, то в регистре статуса UART_SR выставляется бит UART_SR_TXRDY. Это значит, что в регистр UART_THR можно записать следующий передаваемый байт. После записи значения в ргистр UART_THR бит UART_SR_TXRDY в регистре UART_SR выставляется в 0 (передатчик не готов принимать следующее значение). Значение будет "храниться" в UART_THR до тех пор, пока сдвиговый регистр UART, обеспечивающий передачу битов предыдущего байта не освободится. Как только это произойдёт, значение UART_THR запищется в сдвиговый регистр UART, а бит UART_SR_TXRDY в регистре UART_SR выставляется в 1 (передатчик готов принимать следующее значение).

При приёме очередной байт также побитово записывается в сдвиговый регистр приёмника UART. Как только байт прочита полностью, он переносится в регистр UART_RHR, откуда его можно прочитать программно. При этом выставляется бит UART_SR_RXRDY в регистре UART_SR. Если к моменту заполнения сдвигового регистра, регистр UART_RHR не был прочитан, в регистре UART_SR выставляется бит UART_SR_OVRE (переполнение).

Инициализация порта Uart

Инициализацию UART можно осуществить примерно таким кодом:


Три метода работы с последовательным портом: поллинг, прерывания и прямой доступ к памяти

Пожалуй, самый простой способ общения с внешним устройством это просто прямой опрос содержимого регистров портов. Нужно просто достаточно часто проверять: пришёл/ушёл ли очередной байт. При приёме очередного байта нужно просто прочитать его из регистра UART_RHR и обработать. Соответственно при успешной отправке очередного байта, нужно положить в UART_THR очередной байт на отправку (если есть ещё данные для передачи). Такой периодический опрос устройства из основной программы называется английским словом поллинг -- polling.


Мне встречалась такая аналогия для поллинга: хозяин дома постоянно проверяет входную дверь: не пришли ли гости? Если от управляющей программы требуется одновременно делать что-то ещё кроме обмена данными с Uart-портом, то такой подход может усложнить разработку. Здесь на помощь приходят прерывания и прямой доступ к памяти (direct memory access, DMA).

Второй способ общения с портом заключается в том, что мы назначаем прерывание на различные события, происходящие в UART. Этими событиями могут быть либо приход нового байта через линию Rx либо отправка байта через Tx. В основной программе не надо думать о постоянном опросе регистров UART. Если вы хотите использовать прерывания при работе с UART, то при инициализации надо указать, на какие события мы хотим генерировать прерывание, выставив соответствующие биты в регистре UART_IER. В примере кода инициализации выше за это отвечает такая строчка:

UART->UART_IER = UART_IER_RXRDY | UART_IER_OVRE | UART_IER_FRAME;

Здесь мы включаем прерывание на приход очередного байта через Rx, ошибку переполнения и "ошибку фрейма" -- когда приёмник ожидает стоп-бит, но получает старотовый бит. Почему есть прерывание на приём, но нет прерывание на передачу? Просто мы пока что ничего не передаём. Как только начнём что-то отправлять, добавим прерывание на передачу байтов. Каждое устройство UART/USART имеет свой вектор прерывания, один на устройство. Разные события в одном и том же устройстве вызывают один и тот же обработчик прерывания. В этом обработчике надо проверить регистры устройства, чтобы понять, что за событие вызвало прерывание. Обработчик прерываний может выглядеть примерно так:





Механизм прямого доступа к памяти (DMA) позволяет ещё больше избавить центральный процессор микроконтроллера от необходимости отвлекаться на приём и передачу. Вместо этого усторойство (в нашем случае UART) само считывет из указанной области памяти данные байт за байтом (или слово за словом) и отправляет их на вывод, или наоборот последовательно складывает принятые данные в указанную область памяти. DMA не поддерживается некоторыми ARM Cortex M микроконтроллерами, например его нет в lpc1114. Но в ардуиновском sam3x8e он есть. В sam3x8e механизм DMA реализован единообразно для различных периферийных устройств (ADC, DAC, UART, USART и.т.д.). USART поддерживает дуплексный обмен с помощью DMA. Это значит, что в один и тот же момент времени данные могут передвавться из одного буфера в линию Tx, и приниматься через линию Rx в другой буфер. Механизм DMA эффективно сочетается с прерываниями: можно назначить прерывание на окончание приёма/передачи буфера и, например, назначить передачу по DMA следующего буфера. На самом деле в sam3x8e в мезанизме DMA аппаратно встроена "очередь" из двух буферов (об этом -- ниже), так что если вам надо передавать не больше двух раздельно лежащих буферов, то можно обойтись и без прерывание.

Итак, в таблице регистров UART/USART я указал, что по смещениям 0x100-0x124 лежит область PDC, регистров для управления DMA. Вот эти регистры:


Смещение Регистр UART Регистр USART Назначение Доступ
0x0100 UART_RPR US_RPR Receive Pointer Register RW
0x0104 UART_RCR US_RCR Receive Counter Register RW
0x0108 UART_TPR US_TPR Transmit Pointer Register RW
0x010C UART_TCR US_TCR Transmit Counter Register RW
0x0110 UART_RNPR US_RNPR Receive Next Pointer Register RW
0x0114 UART_RNCR US_RNCR Receive Next Counter Register RW
0x0118 UART_TNPR US_TNPR Transmit Next Pointer Register RW
0x011C UART_TNCR US_TNCR Transmit Next Counter Register RW
0x0120 UART_PTCR US_PTCR Transfer Control Register W
0x0124 UART_PTSR US_PTSR Transfer Status Register R


Эта "структура" не является специфичной для UART/USART. Точно такие же регистры есть и у другой периферии: ADC, DAC, SPI, TWI итд.  Для того, чтобы передать какой-то буфер с помощью DMA нужно записать в UART_TPR адрес начала буфера, а в UART_TCR количество байтов для передачи. Аналогично для приёма нужно записать в UART_RPR адрес начала буфера, куда будут записаваться принимаемые данные, а в UART_RCR -- количество принимаемых байт. С помощью регистра UART_PTCR можно управлять передачей данных. Например, записав в него бит UART_PTCR_TXTDIS мы остановим передачу данных, а записав бит UART_PTCR_TXTEN -- наоборот, запустим. Мы видим, что имеется "второй комплект" регистров для следующего передаваемого буфера: UART_RNPR, UART_RNCR, UART_TNPR, UART_TNCR. Они задают "следующие" передаваемые буфферы. Как только передача текущего буфера завершится, значения из из регистров UART_RNPR и UART_RNCR будут переданы в регистры UART_RPR и UART_RCR соответственно. Аналогично, регистры UART_TNPR, UART_TNCR будут перенесены в UART_TPR и UART_TCR и следующие буферы станут текущими.

Я создал репозиторий GitHub, в котором собрал кусочки с различными способами работы с UART (не заботясь о стиле и оформлении). Надеюсь этот пост поможет кому-то разобраться с UART в Cortex M3 микроконтроллерах фирмы Atmel. 

воскресенье, 13 ноября 2016 г.

"Низкоуровневое" программирование Arduino DUE без использования среды Arduino

Если вы работаете с платой Arduino Due, у вас может возникнуть необходимость или желание избавится либо от среды разработки либо от библиотеки Arduino. Причины могут быть разными: например, вам нужно использовать язык C, а не Arduino/C++. Или плата Arduino Due подключена к машине, к которой у вас есть доступ только через терминал и командную строку. А может быть вы просто хотите написать весь низкоуровневый код самостоятельно чтобы получше разобраться в тонкостях работы микроконтроллера Atmel sam3x8e.

Моя "среда разработки" состоит из платы Raspberry Pi к которой USB-кабелем может подключаться плата Arduino Due. Клавиатура и монитор отсутствуют, всё взаимодействие через SSH и коммандную строку. На Raspberry Pi установлена операционная система Raspbian (адаптированный под RPi Debian Linux). Raspberry Pi обеспечивает ряд удобств: во-первых, если что-то пойдёт не так, вы скорее всего сожжёте относительно недорогой RPi а не свой основной компьютер. Во-вторых RPi имеет на борту разные интерфейсы, такие как I2C, GPIO, UART итд, которые можно использовать для свзяи с Arduino DUE. При этом платы совместимы по уровню напряжения 3.3В. В третьих репозитории raspbian содержат всё необходимое для работы с Arduino Due.

Для работы нам необходим в первую очередь подходящий "тулчейн" -- набор инструментов состоящий из компилятора gcc и сопутствующих программ.  Установим его из репозитория Raspbian:

>sudo apt-get install gcc-arm-none-eabi

Также нам понадобится утилита коммандной строки bossac, которая служит для прошивания ARM-контроллеров фирмы Atmel. Ставим из репозитория и её:

>sudo apt-get install bossa-cli

В принципе у нас теперь есть необходимый минимум для того, чтобы написать какую-нибудь программу для sam3x8e/Arduino DUE собрать её и закинуть бинарник в память микроконтроллера, примерно так же как я описывал в предыдущих постах (раз, два) про lpc1114. Наш контроллер sam3x8e как и lpc1114 построен по архитектуре ARM Cortex M (только теперь у нас более навороченный Cortex M3, а в lpc1114 был Cortex M0). Нам придётся найти в даташите адреса регистров соответствующей периферии в sam3x8e и управлять ей записывая туда и считывая оттуда данные.

Заметно облегчить задачу может библиотека от производителя микроконтроллера. В простейшем случае она может предствалять из себя набор заголовочных файлоа (*.h файлов), в которых адресам регистров и числовым значениям, зоответствющим различным служебным маскам сопоставляется мнемонические константы. Кроме этого в библиотеку может входить набор относительно низкоуровневых функций и типов, сильно облегчающий и ускоряющий разработку. Библиотека Arduino построена "поверх" такой низкоуровневой библиотеки libsam. На сайте Atmel предлагается скачать библиотеку ASF (Atmel Software Foundation). Сравнивая эти две библиотеки я пришёл к выводу, что libsam либо предшественник либо старая версия ASF. В них довольно много одинакового кода.

Лично мне оказалось проще стартовать с библиотекой libsam, наверно потому что я "выкидывал лишнее" из библиотеки Arduino. Лучше всего воспользоваться готовым "скелетом" C/C++ программы для Ардуино. Мне удалось найти несколько таких "заготовок-скелетов".

http://www.atwillys.de/content/cc/using-custom-ide-and-system-library-on-arduino-due-sam3x8e/?lang=en -- статья о программировании платы Arduino Due на С. С этой странички можно скачать файлы для создания структуры директорий необходимой для сборки ваших C-файлов. Необходимо иметь установленную среду Arduino, сконфигурированную для работы с Arduino Due и интерпретатор php. Скрипт на php "наполняет" директорию проекта копирую в неё из директорий Arduino программы для сборки, закачивание прошивки а так же библиотеки sam3x8e.

https://github.com/sethm/arm_skeleton -- содержит libsam (заголовочные файлы и исходный код), Makefile и файл main.c, в который вам просто нужно поместить свой код. Об инструментах (arm-none-eabi-gcc и bossac) нужно позаботиться самостоятельно, но это не проблема так как они есть в репозитории Raspbian. Нет никаких внешних зависимостей, поэтому всё заводится без проблем. Я остановился на этом "каркасе" для своих проектов и здесь будем рассматривать именно его.

https://github.com/pauldreik/arduino-due-makefile -- по сути просто один Makefile, который позволяет собирать Arduino-проекты (из файлов *.ino) используя лишь командную строку.

Итак, устанавливает "каркасс" по второй ссылке.

>git clone https://github.com/sethm/arm_skeleton.git

в файле main.c реализована программа "плавного" мигания светодиодом с помощью ШИМ. Меняем на свою, супер-примитивную программу мигания светодиодом без всякого ШИМ и плавности.



Для редактирования исходного кода через ssh-подключение я использую редактор vim.


Собираем:

>make

Если сборка прошла без ошибок, то в директории проекта появится файл main.c.bin. Подключаем "programming port" Arduino Due к USB-выходу Raspberry Pi. Прошиваем:

>make prog

После того, как прошивка будет отправлена в контроллер, он будет автоматически сброшен и светодиод должен начать мигать в соответствии с нашей программой.

Итог: мы научились программировать Arduino Due "на низком уровне" без среды и библиотек Arduino.

среда, 21 сентября 2016 г.

Обучаем нейросеть распознавать рукописные цифры с помощью Torch7


В "песочнице" kaggle.com есть задача на классификацию рукописных цифр с использованием набора данных MNIST. Нужно обучить классификатор, которые по изображению написанной цифры размером 28x28 определит, что это за цифра. Это неплохая возможность для знакомства со свёрточными нейросетями, тем более, что именно они дают самую высокую точность на этой задаче.

Итак, имеется датасет размером 70000 цифр. Он условно разбит на тренировочное (42000 цифр) и тестовое (28000 цифр) множества. "Условно", потому что этот набор данных опубликован полностью, поэтому при желании можно найти разметку для тестовых данных, но мы не будем читить.

В качестве baseline возьмём такое простое, но очень эффективное решение пользователя Zhao Hanguang. Это решение использует связку метода главных компонент и SVM, работает очень быстро и всего на 35 главных компонентах даёт впечатляющие 98.243% точности. Задача минимум была побить это решение. Задача максимум - достичь 99% точности. Главной целью этих упражнений было знакомство с пакетом Torch7 и его возможностями "глубинного обучения". Сложным моментом было отсутствие в моём распоряжении машины с поддержкой CUDA, так что приходилось ограничиваться небольшими нейросетями и ждать результатов по несколько часов.

Torch предоставляет интерактивную среду с интерпретатором Lua: всё, что вы вводите интерпретируется как строка Lua и сразу выполняется. В ваших Lua-командах вы можете манипулировать объектами-тензорами (по сути просто многомерными массивами). Операции с тензорами написаны на C (есть реализации, которые используют GPU). Создатели Torch позиционируют его как "матлаб-подобную среду для машинного обучения".

Удобнее всего работать с Torch через веб-оболочку iTorch (аналогичную iPython).

Для подобных задач Torch содержит несколько пакетов: optim для поиска максимумов и минимумов, image работы и изображениями и nn для работы с нейросетями. За основу я взял пример из GitHub проекта Torch.

Итак, подключаем все модули:

--Import all dependencies
require 'nn'
require 'optim'
require 'csvigo'
require 'image'

Модуль csvigo помогает загружать данные в формате csv.

Загружаем данные

--Read datasets
train_data = csvigo.load({path = "~/data/train.csv", mode = "large"})
test_data = csvigo.load({path = "~/data/test.csv", mode = "large"})        

Извлекаем из данных отдельные колонки: признаки и разметку для тренировочных данных, а так же конвертируем таблицы Lua в тензоры Torch

--Create tensors for train and test data
train_feature_tensor = torch.Tensor(#train_data-1, 784)
train_label_tensor = torch.Tensor(#train_data-1, 1)
test_feature_tensor = torch.Tensor(#test_data-1, 784)
--Fill tensor with train and test data from file
for i=2,#train_data do
    train_feature_tensor[{i-1,{}}] = torch.Tensor(train_data[i]):narrow(1,2,784)
    train_label_tensor[i-1] = train_data[i][1]
end
 
for i=2,#test_data do
    test_feature_tensor[{i-1,{}}] = torch.Tensor(test_data[i])
end

Модуль image позволяет легко визуализировать картинки, представленные тензорами. Проверим, что наши цифры загрузились как ожидаолсь...

--Check data. We should see handwritten digits
itorch.image(train_feature_tensor[524]:resize(28,28))
itorch.image(test_feature_tensor[231]:resize(28,28))
 


Теперь задаём нейросеть.


--Create Neural network model
model = nn.Sequential()
 
model:add(nn.SpatialConvolution(1, 16, 5, 5)) --28x28x1 goes in, 24x24x16 goes out
model:add(nn.ReLU()) -- 
model:add(nn.SpatialMaxPooling(2, 2, 2, 2)) --24x24x16 goes in, 12x12x16 goes out
model:add(nn.Dropout(0.2))
 
model:add(nn.SpatialConvolution(16, 32, 5, 5)) --12x12x16 goes in, 8x8x32 goes out
model:add(nn.ReLU()) --
model:add(nn.SpatialMaxPooling(2, 2, 2, 2)) --8x8x32 goes in, 4x4x32 goes out
model:add(nn.Dropout(0.2))
 
model:add(nn.View(4*4*32))
model:add(nn.Linear(4*4*32, 64))
model:add(nn.ReLU()) --
model:add(nn.Dropout(0.2))
model:add(nn.Linear(64, 20))
model:add(nn.ReLU())
model:add(nn.Linear(20, 10))
model:add(nn.LogSoftMax())
 

Можно, например полюбопытствовать, и узнать, сколько всего параметров имеет наша нейросеть. Это, например, сможет дать некоторую интуицию по поводу переобучения.

In [14]:
model:getParameters():size()
Out[14]:
 47590
[torch.LongStorage of size 1]


Видим, что при обучении нам надо будет подобрать 47590 параметров. В Torch индексы отсчитываются от 1. Классы в классификации тоже нумеруются начиная от 1, а у нас из csv приходят 0-based метки классов. Поправим...

train_label_tensor = train_label_tensor + 1 

Во время своих первых экспериментов я столкнулся с такими проблемами: нейросети более простой архитектуры чуть-чуть не дотягивали до baseline, более сложные нейросети (например такая, как описана выше) переобучались: на тренировочном множестве они могли достигать 99%, но на тестовом отставали примерно на полпроцента. Переобучение налицо. Сначала я пытался решить проблему подбором параметра регуляризации и dropout-слоями, но 99% на тестовой выборке не достиг. Тогда я стал "деформировать" цифры из тренировочного множества чтобы увеличить разнообразие обучаюзих примеров. Для этого я нашёл готовую функцию на GitHub (автор - пользователь chsasank).

Кстати, это очень сильно увеличило время тренировки, и оно достигло нескольких часов.Но именно это дало в итоге последние доли процента на тестовых данных.

Обучение нейросети делал "явно". То есть в модуле nn нет "высокоуровневой" операции "тренировать нейросеть". Вместо этого приходится определять градиент функции стоимости и передвать его методам пакета optim. К счастью для этого не надо руками писать метод обратного распространения ошибки, он уже есть в пакете nn (nn,backward).

batchSize = 512
trainSize = train_feature_tensor:size()[1]
batchInputs = torch.Tensor(batchSize, 1, 28, 28)
batchLabels = torch.Tensor(batchSize)
lambda = 0.0005
 
-- this matrix records the current confusion across classes
confusion = optim.ConfusionMatrix(classes)
 
local params, gradParams = model:getParameters()
local optimState = {learningRate=0.04}
for epoch=1,450 do
  --local optimState = {learningRate=0.04 - epoch/1000.0 * 0.03}
  print("Epoch:"..epoch)
  for b = 1,math.ceil(trainSize/batchSize) do
    for i=1,batchSize do
      local originalImage = torch.Tensor(1, 28, 28)
      originalImage:copy(train_feature_tensor[(b*batchSize + i - 1) % trainSize + 1])
      batchInputs[i] = ElasticTransform(originalImage, 100, 10) 
      batchLabels[i] = train_label_tensor[(b*batchSize + i - 1) % trainSize + 1]
    end
 
    --Differentiation
    local function feval(params)
      gradParams:zero()
 
      local outputs = model:forward(batchInputs)
      local loss = criterion:forward(outputs, batchLabels)
      local dloss_doutput = criterion:backward(outputs, batchLabels)
      model:backward(batchInputs, dloss_doutput)
 
      --Regularization\n",
      loss = loss + 0.5 * lambda * torch.norm(params,2)^2 / batchSize;
      gradParams:add( params:clone():mul(lambda) )
 
      -- update confusion
      for i = 1,batchSize do
        confusion:add(outputs[i], batchLabels[i])
      end
 
      return loss,gradParams
    end
    optim.sgd(feval, params, optimState)
 
 -- Too big output
 --   print(confusion)
 --   print("Total valid: "..confusion.totalValid * 100)
    confusion:zero()
  end
end
 

Этот код в основном взят из примеров Torch. Тренировка на моей машине занимает несколько часов. Оценить результаты можно с помощью матрицы ошибок. Сначала посмотрим матрицу на тренировочных данных:

confusion = optim.ConfusionMatrix(classes)
-- test function
function test(eval_features, eval_labels)
  print(eval_features:size())
  print(eval_labels:size())
  confusion:zero()
  -- test samples
  local preds = model:forward(eval_features)
 
  local maxval, pred_idx = torch.max(preds, 2)
 
  -- confusion:
  for i = 1,eval_features:size()[1] do
   -- print("Add: ", pred_idx[i][1] , eval_labels[i][1])  
    confusion:add(pred_idx[i][1], eval_labels[i][1])
  end
 
   -- print confusion matrix
   print(confusion)
   --confusion:zero()
end

train_features_resized = train_feature_tensor:resize(42000,1,28,28)
test(train_features_resized:narrow(1,1,20000), train_label_tensor:narrow(1,1,20000))

Функция test выведет матрицу ошибок. У меня в машине не хватало памяти для того, чтобы пропустить через сеть больше примрно 20000 примеров за раз, так что для матрицы ошибок я взял первые 20000 примеров. В тестовой выборке 28000 примеров, поэтому её пришлось разбить на два куска. Классификация осуществляется методом forward объекта-модели. Насколько я разобрался, метод forward меняет "на месте" (in-place) внутреннее поле объекта и возвращает ссылку на него, поэтому в результате выполнения такого кода

A = model:forward(F1)
B = model:forward(F2)

Объекты A и B будут одинаковы и равны результату model:forward(F2).
"Тетрадка" iTorch выложена на GitHub.