mardi 19 mai 2009

Programmation : 2ème partie, de l'importance de l'optimisation

Suite de ce billet, on va rentrer dans le vif du sujet.
Je ne vais pas détailler l'algorithme mais le code doit calculer les 16 points permettant de calculer l'amplitude finale, cela consiste donc à multiplier chaque point source par l'échantillon et à multiplier chaque point de destination par le sample, à ce stade, nous avons les deux amplitudes (source et destination), il ne reste plus qu'à calculer l'amplitude qui consiste à obtenir la distance entre ces deux points (+ partie fractionnelle).
Voici le code original, en PureBasic, oui je sais, inconnu au bataillon tout ça mais peu importe, n'importe quel programmeur devrait comprendre le code, plutôt simple :

#FP_SHIFT = 15

; init sinc table
SAMPLE_init_sinc_table()

; init sample data
*sample_data = AllocateMemory(16 * 2)

; init mix buffer (4 samples)
; LR - LR - LR - LR
*mix_buffer = AllocateMemory(8 * SizeOf(Long))

; alimentation de données bidon
PokeW(*sample_data + 0, 10)
PokeW(*sample_data + 2, 20)
PokeW(*sample_data + 4, 30)
PokeW(*sample_data + 6, 40)
PokeW(*sample_data + 8, 50)
PokeW(*sample_data + 10, 60)
PokeW(*sample_data + 12, 70)
PokeW(*sample_data + 14, 80)

PokeW(*sample_data + 16, 90)
PokeW(*sample_data + 18, 100)
PokeW(*sample_data + 20, 110)
PokeW(*sample_data + 22, 120)
PokeW(*sample_data + 24, 130)
PokeW(*sample_data + 26, 140)
PokeW(*sample_data + 28, 150)
PokeW(*sample_data + 30, 160)


size_of_word = SizeOf(Word)
table_idx = 0
sample_idx = 0
offset = 0
left_gain = 2408
right_gain = 5156

start0 = GetTickCount_()
For i = 0 To 10000000

While offset < 8

*sinc_table_ptr = *sinc_table + ( ( table_idx + 0 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 0) * size_of_word )
a1 = (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 16) * size_of_word )
a2 = (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 1 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 1) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 17) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 2 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 2) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 18) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 3 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 3) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 19) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 4 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 4) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 20) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 5 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 5) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 21) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 6 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 6) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 22) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 7 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 7) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 23) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 8 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 8) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 24) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 9 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 9) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 25) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 10) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 10) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 26) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 11) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 11) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 27) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 12) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 12) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 28) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 13) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 13) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 29) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 14) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 14) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 30) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

*sinc_table_ptr = *sinc_table + ( ( table_idx + 15) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 15) * size_of_word )
a1 + (*sinc_table_ptr\w * *sample_data_ptr\w)
*sinc_table_ptr = *sinc_table + ( ( table_idx + 31) * size_of_word )
a2 + (*sinc_table_ptr\w * *sample_data_ptr\w)

a1 >> #FP_SHIFT
a2 >> #FP_SHIFT

; SIMPLIFICATION du code (code bidon, ne fonctionne pas)
amplitude = a1 + a2

*mix_buffer_ptr = *mix_buffer + (offset * SizeOf(Long))
newsample = (amplitude * left_gain) >> #FP_SHIFT
*mix_buffer_ptr\l + newsample

*mix_buffer_ptr = *mix_buffer + ((offset + 1) * SizeOf(Long))
newsample = (amplitude * right_gain) >> #FP_SHIFT
*mix_buffer_ptr\l + newsample

offset + size_of_word

; fake sinc table_idx
table_idx + 1

;Debug "SINC_NO_SSE:: A1 = " + Str(a1) + " A2 = " + Str(a2)

Wend

offset = 0
table_idx = 0
sample_idx = 0
Next i
end0 = GetTickCount_()

; result of mixing
amp1 = PeekL(*mix_buffer + 0 * SizeOf(long) )
amp2 = PeekL(*mix_buffer + 1 * SizeOf(long) )
amp3 = PeekL(*mix_buffer + 2 * SizeOf(long) )
amp4 = PeekL(*mix_buffer + 3 * SizeOf(long) )
amp5 = PeekL(*mix_buffer + 4 * SizeOf(long) )
amp6 = PeekL(*mix_buffer + 5 * SizeOf(long) )
amp7 = PeekL(*mix_buffer + 6 * SizeOf(long) )
amp8 = PeekL(*mix_buffer + 7 * SizeOf(long) )

Debug "SINC_NO_SSE:: SSE_CALC amp1 = " + Str(amp1) + " amp2 = " + Str(amp2) + " amp3 = " + Str(amp3) + " amp4 = " + Str(amp4)
Debug "SINC_NO_SSE:: SSE_CALC amp5 = " + Str(amp5) + " amp6 = " + Str(amp6) + " amp7 = " + Str(amp7) + " amp8 = " + Str(amp8)



Explication :
- initialisation d'un tableau contenant les points pré calculés pour des échantillons sur 16 bits ( SAMPLE_init_sinc_table() ),
- allocation d'un tampon contenant les échantillons en entrée (8 échantillons car un échantillons occupe 2 octets * 2 car il est stéréo au format LR (Left - Right) ),
- allocation d'un tampon en sortie au format entier pour stocker pour un entier, l'échantillon stéréo donc un int contient deux shorts (LR),
- alimentation via des données bidons (les PokeW(*sample_data + 0, 10) ... ),
- initialisation de différentes valeurs pour le test,
- puis vient la partie critique, on boucle 10000000 (pour le bench), à chaque itération, nous bouclons 8 fois pour récupérer 2*8 shorts d'échantillons qui seront stockés dans notre tampons de mixage, pour chaque itération, nous calculons le point source et destination, je le scale sur 16 bits (-1 pour le bit de signe) et je stocke le tout dans le mélangeur.

Un gros pâte pour finalement pas grand chose.
Maintenant, voyons ce que ça donnerais en SSE2 :

Procedure.l PB_sse2_mul_add(*p1.Word, *p2.Word, size.l)
Protected *dest_buffer_ptr.Long
Protected idx.l, size_of_elements = size / 8

*dest_buffer_ptr = @dest_buffer()

!MOV esi, [p.p_p1]
!MOV edi, [p.p_p2]
!MOV edx, [p.p_dest_buffer_ptr]
!XOR ecx, ecx
!PXOR xmm4, xmm4

For idx = 0 To size_of_elements - 1
!MOVUPS xmm0, [esi+ecx] ; +16
!MOVUPS xmm1, [esi+ecx] ; +16

!MOVUPS xmm2, [edi+ecx] ; +16

!PMULLW xmm0, xmm2
!PMULHW xmm1, xmm2

!MOVUPS xmm3, xmm0 ; copy orig xmm0

!PUNPCKLWD xmm0, xmm1 ; xmm0 --> loword (this is why it is was copied :p )
!PUNPCKHWD xmm3, xmm1

!PADDD xmm4, xmm0
!PADDD xmm4, xmm3

!MOVUPS [edx], xmm4

!ADD ecx, 16
Next idx

result = PeekL(*dest_buffer_ptr) + PeekL(*dest_buffer_ptr+4) + PeekL(*dest_buffer_ptr+8) + PeekL(*dest_buffer_ptr+12)

ProcedureReturn result
EndProcedure

;- calc amplitude 0
*sinc_table_ptr = *sinc_table + ( ( table_idx + 0 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 0) * size_of_word )

a1 = PB_sse2_mul_add(*sinc_table_ptr, *sample_data_ptr, 16)
a2 = PB_sse2_mul_add(*sinc_table_ptr+32, *sample_data_ptr, 16)

a1 >> #FP_SHIFT
a2 >> #FP_SHIFT

calc_amp(0) = a1 + a2

; fake sinc table_idx
table_idx + 1

Debug "SINC_SSE:: A1 = " + Str(a1) + " A2 = " + Str(a2)

;- calc amplitude 1
*sinc_table_ptr = *sinc_table + ( ( table_idx + 0 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 0) * size_of_word )

a1 = PB_sse2_mul_add(*sinc_table_ptr, *sample_data_ptr, 16)
a2 = PB_sse2_mul_add(*sinc_table_ptr+32, *sample_data_ptr, 16)

a1 >> #FP_SHIFT
a2 >> #FP_SHIFT

calc_amp(1) = a1 + a2

; fake sinc table_idx
table_idx + 1

;- calc amplitude 2
*sinc_table_ptr = *sinc_table + ( ( table_idx + 0 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 0) * size_of_word )

a1 = PB_sse2_mul_add(*sinc_table_ptr, *sample_data_ptr, 16)
a2 = PB_sse2_mul_add(*sinc_table_ptr+32, *sample_data_ptr, 16)

a1 >> #FP_SHIFT
a2 >> #FP_SHIFT

calc_amp(2) = a1 + a2

; fake sinc table_idx
table_idx + 1

;- calc amplitude 3
*sinc_table_ptr = *sinc_table + ( ( table_idx + 0 ) * size_of_word )
*sample_data_ptr = *sample_data + ( ( sample_idx + 0) * size_of_word )

a1 = PB_sse2_mul_add(*sinc_table_ptr, *sample_data_ptr, 16)
a2 = PB_sse2_mul_add(*sinc_table_ptr+32, *sample_data_ptr, 16)

a1 >> #FP_SHIFT
a2 >> #FP_SHIFT

calc_amp(3) = a1 + a2

; fake sinc table_idx
table_idx + 1

*calc_amp_ptr = @calc_amp(0)
*mix_buffer_ptr = *mix_buffer + (0 * SizeOf(Long))
*left_gain_ptr = @left_gain
*right_gain_ptr = @right_gain

; sse mixing

; L - L - L - L
; 159 - 139 - 119 - 99
;* 2048 - 2048 - 2048 - 2048
; ---- ----- ----- -----
;325632 284672 243712 202752

; R - R - R - R
; 159 - 139 - 119 - 99
;* 5156 - 5156 - 5156 - 5156
; ---- ----- ----- -----
; 819804 716684 613564 510444

start1 = GetTickCount_()
For i = 0 To 10000000

!MOV esi, [p_calc_amp_ptr]
!MOV edi, [p_mix_buffer_ptr]
!MOV ebx, [p_left_gain_ptr]
!MOV edx, [p_right_gain_ptr]
!PXOR xmm0, xmm0
; !PXOR xmm1, xmm1
!PXOR xmm2, xmm2
!PXOR xmm3, xmm3
!PXOR xmm4, xmm4
; !PXOR xmm5, xmm5

; store
!MOVUPS xmm0, [esi] ; amp0 - amp1 - amp2 - amp3 (LLLL)
!MOVUPS xmm4, xmm0 ; amp0 - amp1 - amp2 - amp3 (RRRR)
!MOVUPS xmm1, [edi] ; sample0 - sample1 - sample2 - sample3
!MOVUPS xmm5, [edi+16] ; sample4 - sample5 - sample6 - sample7
!MOVUPS xmm2, [ebx] ; ? - ? - ? - left_gain
!MOVUPS xmm3, [edx] ; ? - ? - ? - right_gain

; prepare shift
!PSHUFD xmm2, xmm2, 00b ; left_gain - left_gain - left_gain - left_gain
!PSHUFD xmm3, xmm3, 00b ; right_gain - right_gain - right_gain - right_gain

; calc amplitude
!cvtdq2ps xmm2, xmm2 ; int --> single
!cvtdq2ps xmm3, xmm3 ; int --> single

!cvtdq2ps xmm0, xmm0 ; int --> single
!cvtdq2ps xmm4, xmm4 ; int --> single

; perform 4 singles mul (LLLL)
!mulps xmm0, xmm2 ; amp0*left_gain - amp1*left_gain - amp2*left_gain - amp3*left_gain (create LLLL sample)
; perform 4 singles mul (RRRR)
!mulps xmm4, xmm3 ; amp0*right_gain - amp1*right_gain - amp2*right_gain - amp3*right_gain (create RRRR sample)

!cvtps2dq xmm0, xmm0
!cvtps2dq xmm4, xmm4

; calc fractional bits
!PSRAD xmm0, 15 ; amp0 >> #FP_SHIFT - amp1 >> #FP_SHIFT - amp2 >> #FP_SHIFT - amp3 >> #FP_SHIFT (LLLL)
!PSRAD xmm4, 15 ; amp0 >> #FP_SHIFT - amp1 >> #FP_SHIFT - amp2 >> #FP_SHIFT - amp3 >> #FP_SHIFT (RRRR)

; save orig values (for low 32 bits word)
!MOVUPS xmm2, xmm0
!MOVUPS xmm3, xmm4

; unpack orig 32 bits low words to saved 32 bits high words (LRLR for 1st sample)
!UNPCKLPS xmm0, xmm4 ; LRLR --> xmm0 (2 samples)
; unpack orig 32 bits high words to saved 32 bits low words (LRLR for 2nd sample)
!UNPCKHPS xmm2, xmm3 ; LRLR --> xmm2 (2 samples)

; mix result : 2 samples mixed at a time
!PADDD xmm1, xmm0
!PADDD xmm5, xmm2

!MOVUPS [edi], xmm1
!MOVUPS [edi+16], xmm5

Next i
end1 = GetTickCount_()

; result of mixing
amp1 = PeekL(*mix_buffer + 0 * SizeOf(long) )
amp2 = PeekL(*mix_buffer + 1 * SizeOf(long) )
amp3 = PeekL(*mix_buffer + 2 * SizeOf(long) )
amp4 = PeekL(*mix_buffer + 3 * SizeOf(long) )
amp5 = PeekL(*mix_buffer + 4 * SizeOf(long) )
amp6 = PeekL(*mix_buffer + 5 * SizeOf(long) )
amp7 = PeekL(*mix_buffer + 6 * SizeOf(long) )
amp8 = PeekL(*mix_buffer + 7 * SizeOf(long) )

Debug "SINC_SSE:: SSE_CALC amp1 = " + Str(amp1) + " amp2 = " + Str(amp2) + " amp3 = " + Str(amp3) + " amp4 = " + Str(amp4)
Debug "SINC_SSE:: SSE_CALC amp5 = " + Str(amp5) + " amp6 = " + Str(amp6) + " amp7 = " + Str(amp7) + " amp8 = " + Str(amp8)


Désolé pour les commentaires mais qu'est ce que cela signifie :
- 1ère différence : c'est LA clef de notre optimisation, on voit que je calcule 4 amplitude en appelant 4 fois la procédure PB_sse2_mul_add(), celle-ci me calcule la partie que j'ai mis en gras plus haut. Cela veut dire que, contrairement au code original, je ne calcule pas qu'une amplitude par itération mais 4 simultanément,
- 2ème différence : je stockais un échantillon au format LR, la, vu que j'ai calculé 4 amplitudes, j'en stocke 4 par itération au format LRLR-LRLR au lieu de LR ... tout court !
- ... et c'est tout !

Mais ces deux simples différences sont éloquentes :
- code PureBasic : 6520 ms en moyenne,
- code SSE2 : 170 ms en moyenne.

Soit 38x plus rapide, rien que ça ! Cela ne se voit pas que dans les benchs mais aussi dans le taux d'occupation cpu : j'oscille entre 2 à 5% (en pique), cela peut sembler encore beaucoup mais c'est tout de même deux fois moins gourmand, objectif réussi :)
Je vous laisse zieuter le code, n'oubliez pas google, c'est grâce à lui que je m'en suis sortie mais je reste disponible pour toute aide que je pourrais apporter :)

Aucun commentaire: