weights[[buffer(0)]], constant float& bias[[buffer(1)]], uint2 gid[[thread_position_in_grid]]) { if (gid.x >= in.get_width() || gid.y >= in.get_height()) return; float partial = bias; for (uint i = 0; i < in.get_array_size(); ++i) { float3 in0 = float3(in.read(gid + uint2(-1, -1), i).r, in.read(gid + uint2( 0, -1), i).r, in.read(gid + uint2(+1, -1), i).r); float3 in1 = float3(in.read(gid + uint2(-1, 0), i).r, in.read(gid + uint2( 0, 0), i).r, in.read(gid + uint2(+1, 0), i).r); float3 in2 = float3(in.read(gid + uint2(-1, +1), i).r, in.read(gid + uint2( 0, +1), i).r, in.read(gid + uint2(+1, +1), i).r); float3x3 weight = weights[i]; partial += dot(in0, weight[0]) + dot(in1, weight[1]) + dot(in2, weight[2]); } float p = fmax(partial, 0) + 0.1 * fmin(partial, 0); float4 outColor(p, 0, 0, 0); out.write(outColor, gid); }