La scarsità riveste un ruolo cruciale per ottenere una maggiore efficienza nel deep learning. Tuttavia, per comprendere appieno il suo vero potenziale e utilizzare la scarsità nella vita reale, è necessario combinare ricerca su hardware, software e algoritmi in modo perfetto.
A tale scopo, una libreria versatile e flessibile è fondamentale. Google AI ha compiuto un passo avanti in questa direzione con JaxPruner, la libreria concisa di Google AI dedicata alla ricerca sull’apprendimento automatico.
JaxPruner è una libreria open source basata su JAX per il pruning e l’addestramento sparsi. Il suo obiettivo principale è la riduzione dei parametri sparsi. La libreria mira a migliorare il lavoro di ricerca sulle reti neurali sparse offrendo soluzioni concise per i noti metodi di pruning e addestramento sparsi.
La popolare libreria di ottimizzazione Optax e gli algoritmi utilizzati in JaxPruner condividono la stessa API, semplificando l’integrazione di JaxPruner con altre librerie basate su JAX.
Secondo il documento, la ricerca combina metodi per ottenere la riduzione dei parametri: pruning e addestramento sparsi. Il pruning cerca di generare reti neurali sparse a partire da reti dense per migliorare l’inferenza. L’addestramento sparso si concentra invece sulla creazione di reti neurali sparse da zero, riducendo contemporaneamente i costi di addestramento.
La comunità scientifica e di ricerca si è affidata molto a JAX negli ultimi anni. Le sue caratteristiche e funzioni uniche lo distinguono da altri framework rinomati come TensorFlow e PyTorch. La sua indipendenza dai dati lo rende un’ottima scelta per l’accelerazione hardware. Ciò riduce il tempo necessario per implementare idee complesse, semplificando operazioni come il calcolo dei gradienti, degli hessiani o la vettorizzazione (Babuschkin et al., 2020). Allo stesso tempo, è facile modificare una funzione quando il suo stato completo è contenuto in un’unica posizione.
Nonostante tecniche specifiche come il pruning della grandezza globale (Kao, 2022) e l’addestramento sparso con la riduzione e la quantizzazione N:M (Lew et al., 2022), non esiste ancora una libreria completa per la ricerca sulla riduzione dei parametri in JAX. Ciò ha portato all’introduzione di JaxPruner.
Con JaxPruner, gli scienziati desiderano affrontare domande cruciali come “Quale modello di riduzione dei parametri raggiunge il giusto compromesso tra accuratezza e prestazioni?” e “È possibile addestrare reti neurali sparse senza dover prima addestrare un modello denso di grandi dimensioni?” Per raggiungere questi obiettivi, tre principi hanno guidato lo sviluppo della libreria:
- Ridurre l’attrito per coloro che integrano JaxPruner in basi di codice esistenti. Per farlo, JaxPruner sfrutta la nota libreria di ottimizzazione Optax, che richiede poche modifiche se integrata con altre librerie.
- Offrire un’API generica condivisa tra vari algoritmi per consentire un facile passaggio da un algoritmo all’altro.
- Favorire l’ottimizzazione del deep learning e della riduzione dei parametri con il rilascio di JaxPruner, una libreria di ricerca sull’apprendimento automatico. Questa scoperta facilita il raggiungimento della promessa di riduzione dei parametri nelle applicazioni del mondo reale, aprendo le porte a una migliore collaborazione tra accademici che lavorano su hardware, software e algoritmi. JaxPruner consente alle aziende di sfruttare i vantaggi della riduzione dei parametri nelle reti neurali, accelerando le conversioni di funzioni, agevolando la prototipazione rapida e garantendo un’integrazione senza soluzione di continuità con le basi di codice esistenti.