在公司项目里的一个小角落,尝试使用 Rust + WebAssembly 加速应用里部分计算

选择

首先要分清项目里 WebAssembly 的定位

  1. 应用主体:完成绝大多数逻辑,JS 只作为加载入口和少部分事件绑定桥梁
  2. 工具库:分担一些复杂和耗时的计算,由 JS 决定什么时候调用 wasm 模块函数,通常情况下,由 JS 管理 wasm 的线性内存

此次小 Demo 属于第 2 种情况

选择 Rust Target

rust 支持 wasm32-unknown-unknownwasm32-unknown-emscripten 两种编译目标,后者除了wasm 外,还生成了 emscripten 风格的 JS 作为入口,与 asm.js 的调用风格统一。不过会添加好些运行时代码,与 wasm 之间也隔着一层封装。基于上一项的选择,我们选择前者

1
rustup target add wasm32-unknown-unknown

WebAssembly 考量

数据类型问题

WebAssembly 当前只支持几种有限的数字类型,i32/i64/f32/f64,JS 与之交互时,除了 number 以外的值都要有序列化/反序列化处理。

字符串

可使用 TextEncoder/TextDecoder 将 JS 字符串序列化为 utf-8 字节流,在 Rust 端先将字节流解析为字符串,再使用。

数组

JS 直接操作 wasm 示例的内存,将数组数据写入,调用 wasm 方法时,将数组起始的指针以及数组长度作为参数。

1
2
3
4
5
6
// rust 部分
pub unsafe fn load_image_data(in_image_ptr: *mut u8, width: i32, height: i32) -\> *const u8 {
let arr_len = (width * height * 4) as usize;
let in_image_data = Vec::from_raw_parts(in_image_ptr, arr_len, arr_len);
return in_image_data.as_ptr();
}
1
2
3
4
5
6
// js 部分
const ctx = canvas.getContext('2d')
const imageData = ctx.getImageData(0, 0, 100, 50)
const imgDataVecPtr = copyJsArrayToRust(this.instanceExports, imageData.data)
wasmExports.load_image_data(imgDataVecPtr, 100, 50)
`

上代码

Rust 端

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// main.rs
use std::{mem};
use std::ffi::CString;
use std::os::raw::{c_char};

extern {
fn clog(ptr: *const u8, number: usize);
}

fn js_log(s: String) {
let mut _s = s.clone();
unsafe {
let m = _s.as_mut_vec().as_mut_ptr();
clog(m as *const u8, _s.len());
}
}

#[no_mangle]
pub fn alloc(size: usize) -> *const u8 {
let buf = Vec::with_capacity(size);
let ptr = buf.as_ptr();
mem::forget(buf); // 让 rust 放弃对此段内存的控制权,此函数结束后该段内存对于 rust 来说是泄漏的状态,分配和管理权交给 JS
return ptr;
}

#[no_mangle]
pub fn log_something(text_ptr: *mut c_char) -> *const u8 {
let text = CString::from_raw(text_ptr).into_string().unwrap();
js_log(text);
}

fn main() {} // 留一个 main 确保 rust 能正常编译

JS 端

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
// wasm-util.ts

export interface ModuleExports {
memory: WebAssembly.Memory
alloc(len: number): WasmMemPtr
log_something(text_ptr: number)
}

export function loadWebAssembly<T extends ModuleExports>(input: string, imports: any) {
// Fetch the file and compile it

return fetch(input).then(response => response.arrayBuffer())
.then(buffer => {
// Create the imports for the module, including the
// standard dynamic library imports
imports = imports || {}
imports.env = imports.env || {}
if (!imports.env.memory) {
imports.env.memory = new WebAssembly.Memory({ initial: 1 })
}

// Create the instance.
return WebAssembly.instantiate(buffer, imports)
})
}

export function copyJsStringToRust(module: ModuleExports, str: string) {
const utf8Encoder = new TextEncoder()
const string_buffer = utf8Encoder.encode(str)
const len = string_buffer.length
const ptr = module.alloc(len + 1)

const memory = new Uint8Array(module.memory.buffer, ptr)
for (let i = 0; i < len; i++) {
memory[i] = string_buffer[i]
}

memory[len] = 0 // cstring end

return ptr
}

export function decodeRustString(module: ModuleExports, ptr: WasmMemPtr) {
const collectCString = function*() {
const memory = new Uint8Array(module.memory.buffer)
while (memory[ptr] !== 0) {
if (memory[ptr] === undefined) {
throw new Error('Tried to read undef mem')
}
yield memory[ptr]
ptr += 1
}
}

const buffer_as_u8 = new Uint8Array(collectCString())
const utf8Decoder = new TextDecoder()
const buffer_as_utf8 = utf8Decoder.decode(buffer_as_u8)
return buffer_as_utf8
}

export function copyJsArrayToRust(exports: ModuleExports, arr: number[] | Uint8ClampedArray) {
const { memory, alloc } = exports
const rVecPtr = alloc(arr.length)
const asBytes = new Uint8Array(memory.buffer, rVecPtr, arr.length)
asBytes.set(arr)
return rVecPtr
}

demo.ts

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import {copyJsStringToRust, decodeRustString} from './wasm-util'

const importObj = {
env: {
clog: (ptr: WasmMemPtr, number: number) => {
const str = decodeRustString(this.rut.instanceExports, ptr)
console.log("[rust] " + str)
},
}
}

loadWebAssembly('demo.wasm').then(() => {
const strPtr = copyJsStringToRust(wasmExports, 'Heyhey you you')
wasmExports.log_something(strPtr)
})

// '[rust] Heyhey you you'

编译 WASM

1
cargo rustc --release --target=wasm32-unknown-unknown

WASM 瘦身

参见此文

Cargo.toml 中一些设定

1
2
3
[profile.release]
debug = false
lto = true

例子