defmodule Postgrex.Messages do
  @moduledoc false

  # Encoders/decoders for PostgreSQL frontend/backend protocol messages
  # (protocol version 3.0, see @protocol_vsn_major/@protocol_vsn_minor).
  # Every message is an Erlang record built with `defrecord/2`:
  #
  #   * `parse/3` turns the body of a received backend message into a record;
  #   * `encode_msg/1` turns a frontend record into framed iodata ready to send.

  import Postgrex.BinaryUtils
  import Record, only: [defrecord: 2]

  @protocol_vsn_major 3
  @protocol_vsn_minor 0

  # Authentication request codes (the Int32 that starts an ?R body) and the
  # atoms `decode_auth_type/1` maps them to.
  @auth_types [
    ok: 0,
    kerberos: 2,
    cleartext: 3,
    md5: 5,
    scm: 6,
    gss: 7,
    gss_cont: 8,
    sspi: 9,
    sasl: 10,
    sasl_cont: 11,
    sasl_fin: 12
  ]

  # One-byte field codes found in error (?E) and notice (?N) bodies and the
  # atoms `decode_field_type/1` maps them to.
  @error_fields [
    severity: ?S,
    code: ?C,
    message: ?M,
    detail: ?D,
    hint: ?H,
    position: ?P,
    internal_position: ?p,
    internal_query: ?q,
    where: ?W,
    schema: ?s,
    table: ?t,
    column: ?c,
    data_type: ?d,
    constraint: ?n,
    file: ?F,
    line: ?L,
    routine: ?R
  ]

  defrecord :msg_auth, [:type, :data]
  defrecord :msg_startup, [:params]
  defrecord :msg_password, [:pass]
  defrecord :msg_error, [:fields]
  defrecord :msg_parameter, [:name, :value]
  defrecord :msg_backend_key, [:pid, :key]
  defrecord :msg_ready, [:status]
  defrecord :msg_notice, [:fields]
  defrecord :msg_query, [:statement]
  defrecord :msg_parse, [:name, :statement, :type_oids]
  defrecord :msg_describe, [:type, :name]
  defrecord :msg_flush, []
  defrecord :msg_close, [:type, :name]
  defrecord :msg_parse_complete, []
  defrecord :msg_parameter_desc, [:type_oids]
  defrecord :msg_too_many_parameters, [:len, :max_len]
  defrecord :msg_row_desc, [:fields]
  defrecord :msg_no_data, []
  defrecord :msg_notify, [:pg_pid, :channel, :payload]
  defrecord :msg_bind, [:name_port, :name_stat, :param_formats, :params, :result_formats]
  defrecord :msg_execute, [:name_port, :max_rows]
  defrecord :msg_sync, []
  defrecord :msg_bind_complete, []
  defrecord :msg_close_complete, []
  defrecord :msg_portal_suspend, []
  defrecord :msg_data_row, [:values]
  defrecord :msg_command_complete, [:tag]
  defrecord :msg_empty_query, []
  defrecord :msg_copy_data, [:data]
  defrecord :msg_copy_done, []
  defrecord :msg_copy_fail, [:message]
  defrecord :msg_copy_in_response, [:format, :columns]
  defrecord :msg_copy_both_response, [:format, :columns]
  defrecord :msg_copy_out_response, [:format, :columns]
  defrecord :msg_terminate, []
  defrecord :msg_ssl_request, []
  defrecord :msg_cancel_request, [:pid, :key]
  defrecord :row_field, [:name, :table_oid, :column, :type_oid, :type_size, :type_mod, :format]

  ### decoders ###

  # parse(body, tag, size) decodes the body of one backend message.
  #
  #   * `tag` is the one-byte message type.
  #   * `size` is presumably the body length in bytes: the overflowing ?t
  #     clause derives the parameter count from it as div(size - 2, 4).
  #
  # Bodies without the expected layout raise (FunctionClauseError, MatchError
  # or CaseClauseError) instead of being silently accepted.

  # auth: Int32 request code followed by code-specific data
  def parse(<<type::int32(), rest::binary>>, ?R, _size) do
    type = decode_auth_type(type)

    data =
      case type do
        # md5: the data must be exactly 4 bytes (the salt)
        :md5 ->
          <<salt::binary-size(4)>> = rest
          salt

        # `rest` is already everything after the 4-byte request code, so it is
        # returned as-is instead of being re-sliced from `size`.
        t when t in [:gss_cont, :sasl, :sasl_cont, :sasl_fin] ->
          rest

        _ ->
          nil
      end

    msg_auth(type: type, data: data)
  end

  # backend_key: Int32 process id and Int32 secret key
  def parse(<<pid::int32(), key::int32()>>, ?K, _size) do
    msg_backend_key(pid: pid, key: key)
  end

  # ready: a single transaction-status byte; unknown values raise
  def parse(<<status::int8()>>, ?Z, _size) do
    status =
      case status do
        ?I -> :idle
        ?T -> :transaction
        ?E -> :error
      end

    msg_ready(status: status)
  end

  # parameter_desc: UInt16 count followed by exactly that many UInt32 type OIDs
  def parse(<<len::uint16(), rest::binary-size(len)-unit(32)>>, ?t, _size) do
    oids = for <<oid::uint32() <- rest>>, do: oid
    msg_parameter_desc(type_oids: oids)
  end

  # parameter_desc whose 16-bit count disagrees with the body length. It is
  # accepted only when the real count (derived from `size`), truncated to 16
  # bits, equals the transmitted count, i.e. the count field wrapped around.
  def parse(<<overflow_len::uint16(), _::binary>>, ?t, size) do
    len = div(size - 2, 4)

    case <<len::uint16()>> do
      <<^overflow_len::uint16()>> ->
        msg_too_many_parameters(len: len, max_len: 0xFFFF)

      _ ->
        raise "invalid parameter description"
    end
  end

  # row_desc: UInt16 field count followed by that many field descriptions
  def parse(<<len::uint16(), rest::binary>>, ?T, _size) do
    fields = decode_row_fields(rest, len)
    msg_row_desc(fields: fields)
  end

  # data_row: the column count is skipped; values are kept undecoded
  def parse(<<_::uint16(), rest::binary>>, ?D, _size) do
    msg_data_row(values: rest)
  end

  # notify: Int32 backend pid, then NUL-terminated channel and payload
  def parse(<<pg_pid::int32(), rest::binary>>, ?A, _size) do
    {channel, rest} = decode_string(rest)
    {payload, ""} = decode_string(rest)
    msg_notify(pg_pid: pg_pid, channel: channel, payload: payload)
  end

  # error: {field code, NUL-terminated string} pairs, closed by a NUL byte
  def parse(rest, ?E, _size) do
    fields = decode_fields(rest)
    msg_error(fields: Map.new(fields))
  end

  # notice: same layout as error
  def parse(rest, ?N, _size) do
    fields = decode_fields(rest)
    msg_notice(fields: Map.new(fields))
  end

  # parameter: NUL-terminated name and value
  def parse(rest, ?S, _size) do
    {name, rest} = decode_string(rest)
    {value, ""} = decode_string(rest)
    msg_parameter(name: name, value: value)
  end

  # Messages identified by their tag alone; the body is ignored.
  def parse(_rest, ?1, _size), do: msg_parse_complete()
  def parse(_rest, ?n, _size), do: msg_no_data()
  def parse(_rest, ?2, _size), do: msg_bind_complete()
  def parse(_rest, ?3, _size), do: msg_close_complete()
  def parse(_rest, ?s, _size), do: msg_portal_suspend()

  # command_complete: a single NUL-terminated command tag
  def parse(rest, ?C, _size) do
    {tag, ""} = decode_string(rest)
    msg_command_complete(tag: tag)
  end

  def parse(_rest, ?I, _size), do: msg_empty_query()

  # copy_data / copy_fail keep the raw body; copy_done has none
  def parse(data, ?d, _size), do: msg_copy_data(data: data)
  def parse(_rest, ?c, _size), do: msg_copy_done()
  def parse(message, ?f, _size), do: msg_copy_fail(message: message)

  # copy_in / copy_out / copy_both responses share the layout read by decode_copy/1
  def parse(rest, ?G, _size) do
    {format, columns} = decode_copy(rest)
    msg_copy_in_response(format: format, columns: columns)
  end

  def parse(rest, ?H, _size) do
    {format, columns} = decode_copy(rest)
    msg_copy_out_response(format: format, columns: columns)
  end

  def parse(rest, ?W, _size) do
    {format, columns} = decode_copy(rest)
    msg_copy_both_response(format: format, columns: columns)
  end

  ### encoders ###

  # Frames a frontend message as iodata: an optional tag byte, a 32-bit length
  # that counts itself (4 bytes) plus the payload, then the payload.
  # Raises ArgumentError when the length cannot fit in 32 bits, since the
  # bitstring segment would otherwise silently truncate it.
  def encode_msg(msg) do
    {first, data} = encode(msg)
    size = IO.iodata_length(data) + 4

    if size > 0xFFFFFFFF do
      raise ArgumentError, "payload size #{size} exceeds maximum size of #{0xFFFFFFFF}"
    end

    if first do
      [first, <<size::int32()>>, data]
    else
      [<<size::int32()>>, data]
    end
  end

  # encode(record) returns {tag_byte_or_nil, payload_iodata}. A `nil` tag is
  # used for the untagged startup/ssl/cancel packets.

  # startup: protocol version, NUL-terminated key/value pairs, final NUL
  defp encode(msg_startup(params: params)) do
    params =
      Enum.reduce(params, [], fn {key, value}, acc ->
        [acc, to_string(key), 0, value, 0]
      end)

    vsn = <<@protocol_vsn_major::int16(), @protocol_vsn_minor::int16()>>
    {nil, [vsn, params, 0]}
  end

  # password: sent as given (any terminator is the caller's responsibility)
  defp encode(msg_password(pass: pass)) do
    {?p, [pass]}
  end

  # query: NUL-terminated statement text
  defp encode(msg_query(statement: statement)) do
    {?Q, [statement, 0]}
  end

  # parse: name, statement, Int16 OID count, UInt32 OIDs
  defp encode(msg_parse(name: name, statement: statement, type_oids: oids)) do
    encoded_oids = for oid <- oids, into: "", do: <<oid::uint32()>>
    {?P, [name, 0, statement, 0, <<length(oids)::int16()>>, encoded_oids]}
  end

  # describe: target byte (?S statement / ?P portal) and NUL-terminated name
  defp encode(msg_describe(type: type, name: name)) do
    {?D, [target_byte(type), name, 0]}
  end

  # flush
  defp encode(msg_flush()) do
    {?H, ""}
  end

  # close: same layout as describe
  defp encode(msg_close(type: type, name: name)) do
    {?C, [target_byte(type), name, 0]}
  end

  # bind: portal and statement names, then Int16-counted parameter formats,
  # parameter values (used as given) and result formats
  defp encode(
         msg_bind(
           name_port: port,
           name_stat: stat,
           param_formats: param_formats,
           params: params,
           result_formats: result_formats
         )
       ) do
    pfs = for fmt <- param_formats, into: "", do: <<format(fmt)::int16()>>
    rfs = for fmt <- result_formats, into: "", do: <<format(fmt)::int16()>>

    len_pfs = <<length(param_formats)::int16()>>
    len_rfs = <<length(result_formats)::int16()>>
    len_params = <<length(params)::int16()>>

    {?B, [port, 0, stat, 0, len_pfs, pfs, len_params, params, len_rfs, rfs]}
  end

  # execute: NUL-terminated portal name and Int32 row limit
  defp encode(msg_execute(name_port: port, max_rows: rows)) do
    {?E, [port, 0, <<rows::int32()>>]}
  end

  # sync
  defp encode(msg_sync()) do
    {?S, ""}
  end

  # terminate
  defp encode(msg_terminate()) do
    {?X, ""}
  end

  # ssl_request: untagged request code 80877103 (1234 high 16 bits, 5679 low)
  defp encode(msg_ssl_request()) do
    {nil, <<1234::int16(), 5679::int16()>>}
  end

  # cancel_request: untagged request code 80877102, then pid and secret key
  defp encode(msg_cancel_request(pid: pid, key: key)) do
    {nil, <<1234::int16(), 5678::int16(), pid::int32(), key::int32()>>}
  end

  # copy_data: raw payload
  defp encode(msg_copy_data(data: data)) do
    {?d, data}
  end

  # copy_done
  defp encode(msg_copy_done()) do
    {?c, ""}
  end

  # copy_fail: NUL-terminated failure message
  defp encode(msg_copy_fail(message: msg)) do
    {?f, [msg, 0]}
  end

  ### encode helpers ###

  # Wire format codes for bind parameters/results.
  defp format(:text), do: 0
  defp format(:binary), do: 1

  # Target byte shared by Describe and Close.
  defp target_byte(:statement), do: ?S
  defp target_byte(:portal), do: ?P

  ### decode helpers ###

  # Decodes error/notice fields until the terminating NUL byte.
  defp decode_fields(<<0>>), do: []

  defp decode_fields(<<field::int8(), rest::binary>>) do
    type = decode_field_type(field)
    {string, rest} = decode_string(rest)
    [{type, string} | decode_fields(rest)]
  end

  # Splits a NUL-terminated string off the front of `bin`, returning
  # {string_without_nul, rest}; raises MatchError if there is no NUL byte.
  defp decode_string(bin) do
    {pos, 1} = :binary.match(bin, <<0>>)
    {string, <<0, rest::binary>>} = :erlang.split_binary(bin, pos)
    {string, rest}
  end

  # Decodes exactly `count` row field descriptions. A count that disagrees
  # with the bytes available ends in a clause/match error.
  defp decode_row_fields("", 0), do: []

  defp decode_row_fields(rest, count) when count > 0 do
    {field, rest} = decode_row_field(rest)
    [field | decode_row_fields(rest, count - 1)]
  end

  # One row field description: NUL-terminated name followed by fixed-width
  # attributes. `format` is kept as the raw Int16 code.
  defp decode_row_field(rest) do
    {name, rest} = decode_string(rest)

    <<table_oid::uint32(), column::int16(), type_oid::uint32(), type_size::int16(),
      type_mod::int32(), format::int16(), rest::binary>> = rest

    field =
      row_field(
        name: name,
        table_oid: table_oid,
        column: column,
        type_oid: type_oid,
        type_size: type_size,
        type_mod: type_mod,
        format: format
      )

    {field, rest}
  end

  # Generates one decode_auth_type/1 clause per @auth_types entry; unknown
  # codes raise FunctionClauseError.
  Enum.each(@auth_types, fn {type, value} ->
    def decode_auth_type(unquote(value)), do: unquote(type)
  end)

  # Generates one decode_field_type/1 clause per @error_fields entry; any other
  # code decodes to :unknown.
  Enum.each(@error_fields, fn {field, char} ->
    def decode_field_type(unquote(char)), do: unquote(field)
  end)

  def decode_field_type(_), do: :unknown

  defp decode_format(0), do: :text
  defp decode_format(1), do: :binary

  # Copy response body: Int8 overall format, UInt16 column count, then one
  # Int16 format code per column.
  defp decode_copy(<<format::int8(), len::uint16(), rest::binary-size(len)-unit(16)>>) do
    format = decode_format(format)
    columns = for <<column::int16() <- rest>>, do: decode_format(column)
    {format, columns}
  end
end