1use std::collections::VecDeque;
3use std::fmt::Debug;
4use std::io::{Cursor, ErrorKind};
5#[cfg(test)]
6use std::net::{SocketAddr, TcpListener, TcpStream};
7use std::option::Option;
8
9use bson::doc;
11use bson::document::ValueAccessError;
12
13use crate::byte_counter::ByteCounter;
14
/// Error codes carried by honk-rpc error sections.
///
/// Codes `-1` through `-12` are protocol errors with fixed meanings, positive codes are
/// application-defined runtime errors, and any other value (zero or an unassigned
/// negative code) is preserved verbatim as [`ErrorCode::Unknown`]. The exact integer
/// mapping lives in the `From<i32>` / `From<ErrorCode>` conversions.
#[derive(PartialEq, Eq, Debug)]
pub enum ErrorCode {
    /// `-1`: the received bytes could not be parsed as a BSON object.
    BsonParseFailed,
    /// `-2`: the received document exceeded the allowed size.
    MessageTooBig,
    /// `-3`: the received message does not match the message schema.
    MessageParseFailed,
    /// `-4`: the received message uses an incompatible protocol version.
    MessageVersionIncompatible,
    /// `-5`: a section carried an unknown `id`.
    SectionIdUnknown,
    /// `-6`: a section does not match its schema.
    SectionParseFailed,
    /// `-7`: the request cookie is already in use.
    RequestCookieInvalid,
    /// `-8`: the requested function does not exist in the requested namespace.
    RequestNamespaceInvalid,
    /// `-9`: the requested function does not exist.
    RequestFunctionInvalid,
    /// `-10`: the requested function version does not exist.
    RequestVersionInvalid,
    /// `-11`: the response cookie is not recognized.
    ResponseCookieInvalid,
    /// `-12`: the response state is not valid.
    ResponseStateInvalid,
    /// A positive, application-defined runtime error code.
    Runtime(i32),
    /// Any other code (zero or an unassigned negative value).
    Unknown(i32),
}
49
50#[derive(thiserror::Error, Debug)]
52pub enum Error {
53 #[error("failed to read data from read stream")]
55 ReaderReadFailed(#[source] std::io::Error),
56
57 #[error("received invalid bson document size header value of {0}, must be at least 4")]
59 BsonDocumentSizeTooSmall(i32),
60
61 #[error("received invalid bson document size header value of {0}, must be less than {1}")]
63 BsonDocumentSizeTooLarge(i32, i32),
64
65 #[error("waited longer than {} seconds for read", .0.as_secs_f32())]
67 MessageReadTimedOut(std::time::Duration),
68
69 #[error("failed to parse bson Message document")]
71 BsonDocumentParseFailed(#[source] bson::de::Error),
72
73 #[error("failed to convert bson document to Message")]
75 MessageConversionFailed(#[source] crate::honk_rpc::ErrorCode),
76
77 #[error("failed to serialize bson document")]
79 BsonWriteFailed(#[source] bson::ser::Error),
80
81 #[error("failed to write data to write stream")]
83 WriterWriteFailed(#[source] std::io::Error),
84
85 #[error("failed to flush message to write stream")]
87 WriterFlushFailed(#[source] std::io::Error),
88
89 #[error("recieved error section without cookie")]
91 UnknownErrorSectionReceived(#[source] crate::honk_rpc::ErrorCode),
92
93 #[error(
95 "tried to set invalid max message size; must be >=5 bytes and <= i32::MAX (2147483647)"
96 )]
97 InvalidMaxMesageSize(),
98
99 #[error("queued message section is too large to write; calculated size is {0} but must be less than {1}")]
101 SectionTooLarge(usize, usize),
102}
103
104impl From<i32> for ErrorCode {
105 fn from(value: i32) -> ErrorCode {
106 match value {
107 -1i32 => ErrorCode::BsonParseFailed,
108 -2i32 => ErrorCode::MessageTooBig,
109 -3i32 => ErrorCode::MessageParseFailed,
110 -4i32 => ErrorCode::MessageVersionIncompatible,
111 -5i32 => ErrorCode::SectionIdUnknown,
112 -6i32 => ErrorCode::SectionParseFailed,
113 -7i32 => ErrorCode::RequestCookieInvalid,
114 -8i32 => ErrorCode::RequestNamespaceInvalid,
115 -9i32 => ErrorCode::RequestFunctionInvalid,
116 -10i32 => ErrorCode::RequestVersionInvalid,
117 -11i32 => ErrorCode::ResponseCookieInvalid,
118 -12i32 => ErrorCode::ResponseStateInvalid,
119 value => {
120 if value > 0 {
121 ErrorCode::Runtime(value)
122 } else {
123 ErrorCode::Unknown(value)
124 }
125 }
126 }
127 }
128}
129
130impl From<ErrorCode> for i32 {
131 fn from(err: ErrorCode) -> Self {
132 match err {
133 ErrorCode::BsonParseFailed => -1i32,
134 ErrorCode::MessageTooBig => -2i32,
135 ErrorCode::MessageParseFailed => -3i32,
136 ErrorCode::MessageVersionIncompatible => -4i32,
137 ErrorCode::SectionIdUnknown => -5i32,
138 ErrorCode::SectionParseFailed => -6i32,
139 ErrorCode::RequestCookieInvalid => -7i32,
140 ErrorCode::RequestNamespaceInvalid => -8i32,
141 ErrorCode::RequestFunctionInvalid => -9i32,
142 ErrorCode::RequestVersionInvalid => -10i32,
143 ErrorCode::ResponseCookieInvalid => -11i32,
144 ErrorCode::ResponseStateInvalid => -12i32,
145 ErrorCode::Runtime(val) => val,
146 ErrorCode::Unknown(val) => val,
147 }
148 }
149}
150
151impl std::fmt::Display for ErrorCode {
152 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
153 match self {
154 ErrorCode::BsonParseFailed => write!(f, "ProtocolError: failed to parse BSON object"),
155 ErrorCode::MessageTooBig => write!(f, "ProtocolError: received document too large"),
156 ErrorCode::MessageParseFailed => {
157 write!(f, "ProtocolError: received message has invalid schema")
158 }
159 ErrorCode::MessageVersionIncompatible => write!(
160 f,
161 "ProtocolError: received message has incompatible version"
162 ),
163 ErrorCode::SectionIdUnknown => write!(
164 f,
165 "ProtocolError: received message contains section of unknown type"
166 ),
167 ErrorCode::SectionParseFailed => write!(
168 f,
169 "ProtocolError: received message contains section with invalid schema"
170 ),
171 ErrorCode::RequestCookieInvalid => {
172 write!(f, "ProtocolError: request cookie already in use")
173 }
174 ErrorCode::RequestNamespaceInvalid => write!(
175 f,
176 "ProtocolError: request function does not exist in requested namespace"
177 ),
178 ErrorCode::RequestFunctionInvalid => {
179 write!(f, "ProtocolError: request function does not exist")
180 }
181 ErrorCode::RequestVersionInvalid => {
182 write!(f, "ProtocolError: request function version does not exist")
183 }
184 ErrorCode::ResponseCookieInvalid => {
185 write!(f, "ProtocolError: response cookie is not recognized")
186 }
187 ErrorCode::ResponseStateInvalid => write!(f, "ProtocolError: response state not valid"),
188 ErrorCode::Runtime(code) => write!(f, "RuntimeError: runtime error {}", code),
189 ErrorCode::Unknown(code) => write!(f, "UnknownError: unknown error code {}", code),
190 }
191 }
192}
193
194impl std::error::Error for ErrorCode {}
195
/// Packs a semantic version into the wire integer `0x00MMmmpp`
/// (major in bits 16..24, minor in bits 8..16, patch in bits 0..8).
const fn semver_to_i32(major: u8, minor: u8, patch: u8) -> i32 {
    ((major as i32) << 16) | ((minor as i32) << 8) | (patch as i32)
}

/// Unpacks a wire version produced by [`semver_to_i32`].
///
/// Returns `None` for negative values or values wider than 24 bits, which cannot have
/// been produced by `semver_to_i32`.
const fn i32_to_semver(ver: i32) -> Option<(u8, u8, u8)> {
    if ver < 0 || ver > 0x00ff_ffff {
        return None;
    }
    let major = ((ver >> 16) & 0xff) as u8;
    let minor = ((ver >> 8) & 0xff) as u8;
    let patch = (ver & 0xff) as u8;
    Some((major, minor, patch))
}

/// Protocol version (0.1.0) written into, and required from, every message's `honk_rpc` field.
const HONK_RPC_VERSION: i32 = semver_to_i32(0, 1, 0);
217
/// A decoded honk-rpc message: a protocol version plus one or more sections.
struct Message {
    // Packed protocol version (see `semver_to_i32`); parsing only accepts HONK_RPC_VERSION.
    honk_rpc: i32,
    // Sections in the order they appear in the message's `sections` array.
    sections: Vec<Section>,
}
222
223impl TryFrom<bson::document::Document> for Message {
224 type Error = ErrorCode;
225
226 fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
227 let mut value = value;
228
229 let honk_rpc = match value.get_i32("honk_rpc") {
231 Ok(HONK_RPC_VERSION) => HONK_RPC_VERSION,
232 Ok(honk_rpc) => {
233 return if let Some(_version) = i32_to_semver(honk_rpc) {
234 Err(ErrorCode::MessageVersionIncompatible)
236 } else {
237 Err(ErrorCode::MessageParseFailed)
239 };
240 }
241 Err(_err) => return Err(ErrorCode::MessageParseFailed),
242 };
243
244 if let Ok(sections) = value.get_array_mut("sections") {
245 if sections.is_empty() {
247 return Err(ErrorCode::MessageParseFailed);
248 }
249
250 let mut message = Message {
251 honk_rpc,
252 sections: Default::default(),
253 };
254
255 for section in sections.iter_mut() {
256 if let bson::Bson::Document(section) = std::mem::take(section) {
257 message.sections.push(Section::try_from(section)?);
258 } else {
259 return Err(ErrorCode::SectionParseFailed);
260 }
261 }
262 Ok(message)
263 } else {
264 Err(ErrorCode::MessageParseFailed)
265 }
266 }
267}
268
269impl From<Message> for bson::document::Document {
270 fn from(value: Message) -> bson::document::Document {
271 let mut value = value;
272 let mut message = bson::document::Document::new();
273 message.insert("honk_rpc", value.honk_rpc);
274
275 let mut sections = bson::Array::new();
276 for section in value.sections.drain(0..) {
277 sections.push(bson::Bson::Document(bson::document::Document::from(
278 section,
279 )));
280 }
281 message.insert("sections", sections);
282
283 message
284 }
285}
286
/// Value chosen by a request's sender so that the matching response or error section can be
/// tied back to that request.
pub type RequestCookie = i64;

// Values of a section's `id` field, selecting which section type it is.
const ERROR_SECTION_ID: i32 = 0i32;
const REQUEST_SECTION_ID: i32 = 1i32;
const RESPONSE_SECTION_ID: i32 = 2i32;

/// One entry of a message's `sections` array.
enum Section {
    Error(ErrorSection),
    Request(RequestSection),
    Response(ResponseSection),
}

/// Reports a failure. With a `cookie` it answers that request; a received error section
/// without a cookie makes `Session::process_sections` fail with
/// `Error::UnknownErrorSectionReceived`.
struct ErrorSection {
    cookie: Option<RequestCookie>,
    code: ErrorCode,
    // Optional human-readable description.
    message: Option<String>,
    // Optional arbitrary payload.
    data: Option<bson::Bson>,
}

/// Asks the peer to run `namespace`/`function` at `version` with `arguments`.
struct RequestSection {
    // When `None`, no success or pending response is sent back (errors still are).
    cookie: Option<RequestCookie>,
    // Empty string when the `namespace` field is absent on the wire.
    namespace: String,
    // Must be non-empty; an empty name is rejected with RequestFunctionInvalid.
    function: String,
    // 0 when the `version` field is absent on the wire.
    version: i32,
    // Empty document when the `arguments` field is absent on the wire.
    arguments: bson::document::Document,
}

/// Wire values of a response section's `state` field.
#[repr(i32)]
#[derive(Debug, PartialEq)]
enum RequestState {
    Pending = 0i32,
    Complete = 1i32,
}

/// Answers the request identified by `cookie`.
struct ResponseSection {
    cookie: RequestCookie,
    state: RequestState,
    // Always `None` when `state` is Pending (parsing rejects a pending result).
    result: Option<bson::Bson>,
}
327
328impl TryFrom<bson::document::Document> for Section {
329 type Error = ErrorCode;
330
331 fn try_from(
332 value: bson::document::Document,
333 ) -> Result<Self, <Self as TryFrom<bson::document::Document>>::Error> {
334 match value.get_i32("id") {
335 Ok(ERROR_SECTION_ID) => Ok(Section::Error(ErrorSection::try_from(value)?)),
336 Ok(REQUEST_SECTION_ID) => Ok(Section::Request(RequestSection::try_from(value)?)),
337 Ok(RESPONSE_SECTION_ID) => Ok(Section::Response(ResponseSection::try_from(value)?)),
338 Ok(_) => Err(ErrorCode::SectionIdUnknown),
339 Err(_) => Err(ErrorCode::SectionParseFailed),
340 }
341 }
342}
343
344impl From<Section> for bson::document::Document {
345 fn from(value: Section) -> bson::document::Document {
346 match value {
347 Section::Error(section) => bson::document::Document::from(section),
348 Section::Request(section) => bson::document::Document::from(section),
349 Section::Response(section) => bson::document::Document::from(section),
350 }
351 }
352}
353
354impl TryFrom<bson::document::Document> for ErrorSection {
355 type Error = ErrorCode;
356
357 fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
358 let mut value = value;
359
360 let cookie = match value.get_i64("cookie") {
361 Ok(cookie) => Some(cookie),
362 Err(ValueAccessError::NotPresent) => None,
363 Err(_) => return Err(ErrorCode::SectionParseFailed),
364 };
365
366 let code = match value.get_i32("code") {
367 Ok(code) => ErrorCode::from(code),
368 Err(_) => return Err(ErrorCode::SectionParseFailed),
369 };
370
371 let message = match value.get_str("message") {
372 Ok(message) => Some(message.to_string()),
373 Err(ValueAccessError::NotPresent) => None,
374 Err(_) => return Err(ErrorCode::SectionParseFailed),
375 };
376
377 let data = value.get_mut("data").map(std::mem::take);
378
379 Ok(ErrorSection {
380 cookie,
381 code,
382 message,
383 data,
384 })
385 }
386}
387
388impl From<ErrorSection> for bson::document::Document {
389 fn from(value: ErrorSection) -> bson::document::Document {
390 let mut error_section = bson::document::Document::new();
391 error_section.insert("id", ERROR_SECTION_ID);
392
393 if let Some(cookie) = value.cookie {
394 error_section.insert("cookie", cookie);
395 }
396
397 error_section.insert("code", Into::<i32>::into(value.code));
398
399 if let Some(message) = value.message {
400 error_section.insert("message", message);
401 }
402
403 if let Some(data) = value.data {
404 error_section.insert("data", data);
405 }
406
407 error_section
408 }
409}
410
411impl TryFrom<bson::document::Document> for RequestSection {
412 type Error = ErrorCode;
413
414 fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
415 let mut value = value;
416
417 let cookie = match value.get_i64("cookie") {
418 Ok(cookie) => Some(cookie),
419 Err(ValueAccessError::NotPresent) => None,
420 Err(_) => return Err(ErrorCode::SectionParseFailed),
421 };
422
423 let namespace = match value.get_str("namespace") {
424 Ok(namespace) => namespace.to_string(),
425 Err(ValueAccessError::NotPresent) => String::default(),
426 Err(_) => return Err(ErrorCode::SectionParseFailed),
427 };
428
429 let function = match value.get_str("function") {
430 Ok(function) => {
431 if function.is_empty() {
432 return Err(ErrorCode::RequestFunctionInvalid);
433 } else {
434 function.to_string()
435 }
436 }
437 Err(_) => return Err(ErrorCode::SectionParseFailed),
438 };
439
440 let version = match value.get_i32("version") {
441 Ok(version) => version,
442 Err(ValueAccessError::NotPresent) => 0i32,
443 Err(_) => return Err(ErrorCode::SectionParseFailed),
444 };
445
446 let arguments = match value.get_document_mut("arguments") {
447 Ok(arguments) => std::mem::take(arguments),
448 Err(ValueAccessError::NotPresent) => bson::document::Document::new(),
449 Err(_) => return Err(ErrorCode::SectionParseFailed),
450 };
451
452 Ok(RequestSection {
453 cookie,
454 namespace,
455 function,
456 version,
457 arguments,
458 })
459 }
460}
461
462impl From<RequestSection> for bson::document::Document {
463 fn from(value: RequestSection) -> bson::document::Document {
464 let mut request_section = bson::document::Document::new();
465 request_section.insert("id", REQUEST_SECTION_ID);
466
467 if let Some(cookie) = value.cookie {
468 request_section.insert("cookie", cookie);
469 }
470
471 if !value.namespace.is_empty() {
472 request_section.insert("namespace", value.namespace);
473 }
474
475 request_section.insert("function", value.function);
476
477 if value.version != 0i32 {
478 request_section.insert("version", value.version);
479 }
480
481 request_section.insert("arguments", value.arguments);
482
483 request_section
484 }
485}
486
487impl TryFrom<bson::document::Document> for ResponseSection {
488 type Error = ErrorCode;
489
490 fn try_from(value: bson::document::Document) -> Result<Self, Self::Error> {
491 let mut value = value;
492 let cookie = match value.get_i64("cookie") {
493 Ok(cookie) => cookie,
494 Err(_) => return Err(ErrorCode::SectionParseFailed),
495 };
496
497 let state = match value.get_i32("state") {
498 Ok(0i32) => RequestState::Pending,
499 Ok(1i32) => RequestState::Complete,
500 Ok(_) => return Err(ErrorCode::ResponseStateInvalid),
501 Err(_) => return Err(ErrorCode::SectionParseFailed),
502 };
503
504 let result = value.get_mut("result").map(std::mem::take);
505
506 if state == RequestState::Pending && result.is_some() {
508 return Err(ErrorCode::SectionParseFailed);
509 }
510
511 Ok(ResponseSection {
512 cookie,
513 state,
514 result,
515 })
516 }
517}
518
519impl From<ResponseSection> for bson::document::Document {
520 fn from(value: ResponseSection) -> bson::document::Document {
521 let mut response_section = bson::document::Document::new();
522 response_section.insert("id", RESPONSE_SECTION_ID);
523
524 response_section.insert("cookie", value.cookie);
525 response_section.insert("state", value.state as i32);
526
527 if let Some(result) = value.result {
528 response_section.insert("result", result);
529 }
530
531 response_section
532 }
533}
534
/// A group of remotely callable functions served under one namespace.
///
/// `Session::update` hands each inbound request to the `ApiSet` whose `namespace()` equals
/// the request's namespace. Keep the apiset slice passed to `Session::update` sorted by
/// namespace; the lookup may rely on it.
pub trait ApiSet {
    /// The namespace this apiset serves.
    fn namespace(&self) -> &str;

    /// Executes `name` at `version` with `args`.
    ///
    /// - `Some(Ok(result))`: the call finished; a complete response carrying `result` is sent
    ///   if `request_cookie` is `Some`.
    /// - `Some(Err(code))`: the call failed; an error section with `code` and
    ///   `request_cookie` is sent (even when the cookie is `None`).
    /// - `None`: the call is still running; a pending response is sent if `request_cookie`
    ///   is `Some`, and the eventual outcome should be reported through `next_result`.
    fn exec_function(
        &mut self,
        name: &str,
        version: i32,
        args: bson::document::Document,
        request_cookie: Option<RequestCookie>,
    ) -> Option<Result<Option<bson::Bson>, ErrorCode>>;

    /// Called once per `Session::update`, after that update's requests were dispatched.
    /// Does nothing by default.
    fn update(&mut self) {}

    /// Returns the next finished deferred call, if any.
    ///
    /// Called repeatedly after `update` until it returns `None`. `Ok` results are sent as
    /// complete responses and `Err` codes as error sections for the given cookie. Returns
    /// `None` by default.
    fn next_result(&mut self) -> Option<(RequestCookie, Result<Option<bson::Bson>, ErrorCode>)> {
        None
    }
}
668
/// A reply the peer sent for one of this session's outbound requests, as returned by
/// `Session::client_next_response` / `Session::client_drain_responses`.
pub enum Response {
    /// The peer received the request and is still working on it.
    Pending {
        /// Cookie of the request this reply belongs to.
        cookie: RequestCookie,
    },
    /// The request completed; `result` is whatever the remote function returned.
    Success {
        /// Cookie of the request this reply belongs to.
        cookie: RequestCookie,
        /// The function's return value, if any.
        result: Option<bson::Bson>,
    },
    /// The peer answered the request with an error section.
    Error {
        /// Cookie of the request this reply belongs to.
        cookie: RequestCookie,
        /// The error code from the peer's error section.
        error_code: ErrorCode,
    },
}
691
/// Maximum message size, in bytes, a new `Session` uses (4 KiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 * 1024;
/// How long a new `Session` may go without receiving bytes before reads time out.
pub const DEFAULT_MAX_WAIT_TIME: std::time::Duration = std::time::Duration::from_secs(60);

// Byte sizes of the parts of the smallest possible message. Presumably this mirrors the
// single-Null-section message that `get_message_overhead` serializes; verify if changed.
// BSON document size header (an i32).
const HEADER_SIZE: usize = 4usize;
// `honk_rpc` element: 1 type byte + "honk_rpc\0" (9) + i32 value (4).
const HONK_RPC_SIZE: usize = 14usize;
// `sections` element: 1 type byte + "sections\0" (9) + array document holding one Null entry
// (4 size + 1 type + "0\0" (2) + 1 terminator).
const SECTIONS_SIZE: usize = 18usize;
// Document terminator byte.
const FOOTER_SIZE: usize = 1usize;

// Smallest valid message; `set_max_message_size` rejects limits below this.
const MIN_MESSAGE_SIZE: usize = HEADER_SIZE + HONK_RPC_SIZE + SECTIONS_SIZE + FOOTER_SIZE;
710
711pub fn get_message_overhead() -> Result<usize, Error> {
715 let message = doc! {
718 "honk_rpc" : HONK_RPC_VERSION,
719 "sections" : [
720 bson::Bson::Null
721 ]
722 };
723
724 let mut counter: ByteCounter = Default::default();
725 message
726 .to_writer(&mut counter)
727 .map_err(Error::BsonWriteFailed)?;
728
729 Ok(counter.bytes())
730}
731
732pub fn get_error_section_size(
737 cookie: Option<RequestCookie>,
738 message: Option<String>,
739 data: Option<bson::Bson>,
740) -> Result<usize, Error> {
741 let mut error_section = doc! {
742 "id": ERROR_SECTION_ID,
743 "code": Into::<i32>::into(ErrorCode::Unknown(0)),
744 };
745
746 if let Some(cookie) = cookie {
747 error_section.insert("cookie", bson::Bson::Int64(cookie));
748 }
749
750 if let Some(message) = message {
751 error_section.insert("message", bson::Bson::String(message));
752 }
753
754 if let Some(data) = data {
755 error_section.insert("data", data);
756 }
757
758 let mut counter: ByteCounter = Default::default();
759 error_section
760 .to_writer(&mut counter)
761 .map_err(Error::BsonWriteFailed)?;
762
763 Ok(counter.bytes())
764}
765
766pub fn get_request_section_size(
771 cookie: Option<RequestCookie>,
772 namespace: Option<String>,
773 function: String,
774 version: Option<i32>,
775 arguments: Option<bson::Document>,
776) -> Result<usize, Error> {
777 let mut request_section = doc! {
778 "id": REQUEST_SECTION_ID,
779 "function": bson::Bson::String(function),
780 };
781
782 if let Some(cookie) = cookie {
783 request_section.insert("cookie", bson::Bson::Int64(cookie));
784 }
785
786 if let Some(namespace) = namespace {
787 request_section.insert("namespace", bson::Bson::String(namespace));
788 }
789
790 if let Some(version) = version {
791 request_section.insert("version", bson::Bson::Int32(version));
792 }
793
794 if let Some(arguments) = arguments {
795 request_section.insert("arguments", arguments);
796 }
797
798 let mut counter: ByteCounter = Default::default();
799 request_section
800 .to_writer(&mut counter)
801 .map_err(Error::BsonWriteFailed)?;
802
803 Ok(counter.bytes())
804}
805
806pub fn get_response_section_size(result: Option<bson::Bson>) -> Result<usize, Error> {
811 let mut response_section = doc! {
812 "id": RESPONSE_SECTION_ID,
813 "cookie": bson::Bson::Int64(0),
814 "state": bson::Bson::Int32(0),
815 };
816
817 if let Some(result) = result {
818 response_section.insert("result", result);
819 }
820
821 let mut counter: ByteCounter = Default::default();
822 response_section
823 .to_writer(&mut counter)
824 .map_err(Error::BsonWriteFailed)?;
825
826 Ok(counter.bytes())
827}
828
/// A honk-rpc endpoint over a byte stream `RW`, driven by repeated calls to `update`.
///
/// The methods require `RW: Read + Write + Send`; reads that report `WouldBlock`/`TimedOut`
/// are treated as "no data yet", so a non-blocking stream can be used.
pub struct Session<RW> {
    // The underlying transport.
    stream: RW,
    // Serialized messages not yet accepted by `stream`.
    message_write_buffer: VecDeque<u8>,

    // None while the 4-byte size header of the next message is being read; Some(n) when n
    // more bytes of the current message are still needed.
    remaining_byte_count: Option<usize>,
    // Bytes of the message currently being read, size header included.
    message_read_buffer: Vec<u8>,
    // Sections received but not yet processed.
    pending_sections: VecDeque<Section>,
    // Received requests awaiting dispatch to an ApiSet.
    inbound_requests: Vec<RequestSection>,
    // Replies to this session's own requests, drained by the client_* methods.
    inbound_responses: VecDeque<Response>,

    // Scratch buffer a message is serialized into before its size is checked.
    message_serialization_buffer: VecDeque<u8>,
    // Cookie assigned to the next `client_call`.
    next_cookie: RequestCookie,
    // Encoded sections queued for the next outbound message(s).
    outbound_sections: Vec<bson::Document>,

    // Largest message, in bytes, accepted or produced.
    max_message_size: usize,
    // Longest time allowed without receiving any bytes.
    max_wait_time: std::time::Duration,
    // When bytes were last received (initialized at construction).
    read_timestamp: std::time::Instant,
}
870
871#[allow(dead_code)]
872impl<RW> Session<RW>
873where
874 RW: std::io::Read + std::io::Write + Send,
875{
876 pub fn set_max_message_size(&mut self, max_message_size: i32) -> Result<(), Error> {
878 if max_message_size < MIN_MESSAGE_SIZE as i32 {
879 Err(Error::InvalidMaxMesageSize())
881 } else {
882 self.max_message_size = max_message_size as usize;
883 Ok(())
884 }
885 }
886
887 pub fn get_max_message_size(&self) -> usize {
889 self.max_message_size
890 }
891
892 pub fn set_max_wait_time(&mut self, max_wait_time: std::time::Duration) {
894 self.max_wait_time = max_wait_time;
895 }
896
897 pub fn get_max_wait_time(&self) -> std::time::Duration {
899 self.max_wait_time
900 }
901
902 pub fn new(stream: RW) -> Self {
904 let mut message_write_buffer: VecDeque<u8> = Default::default();
905 message_write_buffer.reserve(DEFAULT_MAX_MESSAGE_SIZE);
906
907 let mut message_serialization_buffer: VecDeque<u8> = Default::default();
908 message_serialization_buffer.reserve(DEFAULT_MAX_MESSAGE_SIZE);
909
910 Session {
911 stream,
912 message_write_buffer,
913 remaining_byte_count: None,
914 message_read_buffer: Default::default(),
915 pending_sections: Default::default(),
916 inbound_requests: Default::default(),
917 inbound_responses: Default::default(),
918 message_serialization_buffer,
919 next_cookie: Default::default(),
920 outbound_sections: Default::default(),
921 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
922 max_wait_time: DEFAULT_MAX_WAIT_TIME,
923 read_timestamp: std::time::Instant::now(),
924 }
925 }
926
927 pub fn into_stream(self) -> RW {
929 self.stream
930 }
931
932 fn stream_read(&mut self, buffer: &mut [u8]) -> Result<usize, Error> {
934 match self.stream.read(buffer) {
935 Err(err) => {
936 if err.kind() == ErrorKind::WouldBlock || err.kind() == ErrorKind::TimedOut {
937 if std::time::Instant::now().duration_since(self.read_timestamp)
939 > self.max_wait_time
940 {
941 Err(Error::MessageReadTimedOut(self.max_wait_time))
942 } else {
943 Ok(0)
944 }
945 } else {
946 Err(Error::ReaderReadFailed(err))
947 }
948 }
949 Ok(0) => Err(Error::ReaderReadFailed(std::io::Error::from(
950 ErrorKind::UnexpectedEof,
951 ))),
952 Ok(count) => {
953 self.read_timestamp = std::time::Instant::now();
955 Ok(count)
956 }
957 }
958 }
959
960 fn read_message_size(&mut self) -> Result<(), Error> {
962 match self.remaining_byte_count {
963 Some(_remaining) => Ok(()),
965 None => {
967 assert!(self.message_read_buffer.len() < std::mem::size_of::<i32>());
969 let bytes_needed = std::mem::size_of::<i32>() - self.message_read_buffer.len();
970 let mut buffer = [0u8; std::mem::size_of::<i32>()];
972 let buffer = &mut buffer[0..bytes_needed];
974 match self.stream_read(buffer) {
975 Err(err) => Err(err),
976 Ok(0) => Ok(()),
977 Ok(count) => {
978 #[cfg(test)]
979 println!("<<< read {} bytes for message header", count);
980 self.message_read_buffer
981 .extend_from_slice(&buffer[0..count]);
982
983 if self.message_read_buffer.len() == std::mem::size_of::<i32>() {
985 let size = &self.message_read_buffer.as_slice();
986 let size: i32 = (size[0] as i32)
987 | (size[1] as i32) << 8
988 | (size[2] as i32) << 16
989 | (size[3] as i32) << 24;
990 if size <= std::mem::size_of::<i32>() as i32 {
992 return Err(Error::BsonDocumentSizeTooSmall(size));
993 }
994 if size as usize > self.max_message_size {
996 return Err(Error::BsonDocumentSizeTooLarge(
997 size,
998 self.max_message_size as i32,
999 ));
1000 }
1001
1002 let size = size as usize - std::mem::size_of::<i32>();
1004
1005 self.remaining_byte_count = Some(size);
1006 }
1007 Ok(())
1008 }
1009 }
1010 }
1011 }
1012 }
1013
1014 fn read_message(&mut self) -> Result<Option<Message>, Error> {
1016 self.read_message_size()?;
1018 if let Some(remaining) = self.remaining_byte_count {
1020 #[cfg(test)]
1021 println!("--- message requires {} more bytes", remaining);
1022
1023 let mut buffer = vec![0u8; remaining];
1024 match self.stream_read(&mut buffer) {
1025 Err(err) => Err(err),
1026 Ok(0) => Ok(None),
1027 Ok(count) => {
1028 #[cfg(test)]
1029 println!("<<< read {} bytes", count);
1030 self.message_read_buffer
1032 .extend_from_slice(&buffer[0..count]);
1033 if remaining == count {
1034 self.remaining_byte_count = None;
1035
1036 let mut cursor = Cursor::new(std::mem::take(&mut self.message_read_buffer));
1037 let bson = bson::document::Document::from_reader(&mut cursor)
1038 .map_err(Error::BsonDocumentParseFailed)?;
1039
1040 self.message_read_buffer = cursor.into_inner();
1042 self.message_read_buffer.clear();
1043
1044 #[cfg(test)]
1045 println!("<<< read message: {}", bson);
1046
1047 Ok(Some(
1048 Message::try_from(bson).map_err(Error::MessageConversionFailed)?,
1049 ))
1050 } else {
1051 self.remaining_byte_count = Some(remaining - count);
1053 Ok(None)
1054 }
1055 }
1056 }
1057 } else {
1058 Ok(None)
1059 }
1060 }
1061
1062 fn read_sections(&mut self) -> Result<(), Error> {
1064 loop {
1065 match self.read_message() {
1066 Ok(Some(mut message)) => {
1067 self.pending_sections.extend(message.sections.drain(..));
1068 }
1069 Ok(None) => return Ok(()),
1070 Err(err) => {
1071 match err {
1072 Error::MessageReadTimedOut(_) | Error::ReaderReadFailed(_) => {
1074 if self.pending_sections.is_empty() && self.inbound_responses.is_empty()
1076 {
1077 return Err(err);
1078 }
1079 return Ok(());
1080 }
1081 _ => return Err(err),
1083 }
1084 }
1085 }
1086 }
1087 }
1088
1089 fn process_sections(&mut self) -> Result<(), Error> {
1091 while let Some(section) = self.pending_sections.pop_front() {
1092 match section {
1093 Section::Error(error) => {
1094 if let Some(cookie) = error.cookie {
1095 self.inbound_responses.push_back(Response::Error {
1097 cookie,
1098 error_code: error.code,
1099 });
1100 } else {
1101 return Err(Error::UnknownErrorSectionReceived(error.code));
1102 }
1103 }
1104 Section::Request(request) => {
1105 self.inbound_requests.push(request);
1107 }
1108 Section::Response(response) => {
1109 match (response.cookie, response.state, response.result) {
1112 (cookie, RequestState::Complete, result) => {
1113 self.inbound_responses
1114 .push_back(Response::Success { cookie, result });
1115 }
1116 (cookie, RequestState::Pending, _) => {
1117 self.inbound_responses
1118 .push_back(Response::Pending { cookie });
1119 }
1120 }
1121 }
1122 }
1123 }
1124
1125 Ok(())
1126 }
1127
1128 fn push_outbound_section(&mut self, section: Section) -> Result<(), Error> {
1130 let max_section_size = self.max_message_size - MIN_MESSAGE_SIZE;
1131
1132 let mut counter: ByteCounter = Default::default();
1133 let section: bson::Document = section.into();
1134 section
1135 .to_writer(&mut counter)
1136 .map_err(Error::BsonWriteFailed)?;
1137 let section_size = counter.bytes();
1138
1139 if section_size <= max_section_size {
1140 self.outbound_sections.push(section);
1141 Ok(())
1142 } else {
1143 Err(Error::SectionTooLarge(section_size, max_section_size))
1144 }
1145 }
1146
1147 fn serialize_messages(&mut self) -> Result<(), Error> {
1149 if self.outbound_sections.is_empty() {
1151 return Ok(());
1152 }
1153
1154 let message = Message {
1156 honk_rpc: HONK_RPC_VERSION,
1157 sections: Default::default(),
1158 };
1159 let mut message = bson::document::Document::from(message);
1160 message.insert("sections", std::mem::take(&mut self.outbound_sections));
1161 self.serialize_messages_impl(message)
1162 }
1163
1164 fn serialize_messages_impl(
1166 &mut self,
1167 mut message: bson::document::Document,
1168 ) -> Result<(), Error> {
1169 self.message_serialization_buffer.clear();
1170 message
1171 .to_writer(&mut self.message_serialization_buffer)
1172 .map_err(Error::BsonWriteFailed)?;
1173
1174 if self.message_serialization_buffer.len() > self.max_message_size {
1175 let sections = message.get_array_mut("sections").unwrap();
1177 assert!(sections.len() > 1);
1178
1179 let right = doc! {
1180 "honk_rpc" : HONK_RPC_VERSION,
1181 "sections" : sections.split_off(sections.len() / 2),
1182 };
1183 let left = message;
1184
1185 self.serialize_messages_impl(left)?;
1186 self.serialize_messages_impl(right)?;
1187 } else {
1188 #[cfg(test)]
1189 println!(">>> write message: {:?}", message);
1190 self.message_write_buffer
1192 .append(&mut self.message_serialization_buffer);
1193 }
1194
1195 Ok(())
1196 }
1197
1198 fn write_pending_data(&mut self) -> Result<(), Error> {
1200 let bytes_written = self.write_pending_data_impl()?;
1201 self.stream.flush().map_err(Error::WriterWriteFailed)?;
1202 self.message_write_buffer.drain(0..bytes_written);
1204 self.message_write_buffer.make_contiguous();
1206
1207 Ok(())
1208 }
1209
1210 fn write_pending_data_impl(&mut self) -> Result<usize, Error> {
1211 let (mut pending_data, empty): (&[u8], &[u8]) = self.message_write_buffer.as_slices();
1213 assert!(empty.is_empty());
1214 let pending_bytes: usize = pending_data.len();
1215 let mut bytes_written: usize = 0usize;
1216
1217 while bytes_written != pending_bytes {
1218 match self.stream.write(pending_data) {
1219 Err(err) => {
1220 let kind = err.kind();
1221 if kind == ErrorKind::WouldBlock || kind == ErrorKind::TimedOut {
1222 return Ok(bytes_written);
1224 } else {
1225 return Err(Error::WriterWriteFailed(err));
1226 }
1227 }
1228 Ok(count) => {
1229 bytes_written += count;
1230 #[cfg(test)]
1231 println!(">>> sent {} of {} bytes", bytes_written, pending_bytes);
1232 pending_data = &pending_data[count..];
1233 }
1234 }
1235 }
1236
1237 Ok(bytes_written)
1238 }
1239
1240 pub fn update(&mut self, apisets: Option<&mut [&mut dyn ApiSet]>) -> Result<(), Error> {
1242 self.read_sections()?;
1244 self.process_sections()?;
1246
1247 let apisets = apisets.unwrap_or(&mut []);
1249 self.handle_requests(apisets)?;
1250
1251 self.serialize_messages()?;
1253
1254 self.write_pending_data()?;
1256
1257 Ok(())
1258 }
1259
1260 fn handle_requests(&mut self, apisets: &mut [&mut dyn ApiSet]) -> Result<(), Error> {
1262 let mut inbound_requests = std::mem::take(&mut self.inbound_requests);
1264 for mut request in inbound_requests.drain(..) {
1265 if let Ok(idx) =
1266 apisets.binary_search_by(|probe| probe.namespace().cmp(&request.namespace))
1267 {
1268 let apiset = match apisets.get_mut(idx) {
1269 Some(apiset) => apiset,
1270 None => unreachable!(),
1271 };
1272 match apiset.exec_function(
1273 &request.function,
1274 request.version,
1275 std::mem::take(&mut request.arguments),
1276 request.cookie,
1277 ) {
1278 Some(Ok(result)) => {
1280 if let Some(cookie) = request.cookie {
1281 self.push_outbound_section(Section::Response(ResponseSection {
1282 cookie,
1283 state: RequestState::Complete,
1284 result,
1285 }))?;
1286 }
1287 }
1288 Some(Err(error_code)) => {
1290 self.push_outbound_section(Section::Error(ErrorSection {
1291 cookie: request.cookie,
1292 code: error_code,
1293 message: None,
1294 data: None,
1295 }))?;
1296 }
1297 None => {
1299 if let Some(cookie) = request.cookie {
1300 self.push_outbound_section(Section::Response(ResponseSection {
1301 cookie,
1302 state: RequestState::Pending,
1303 result: None,
1304 }))?;
1305 }
1306 }
1307 }
1308 } else {
1309 self.push_outbound_section(Section::Error(ErrorSection {
1311 cookie: request.cookie,
1312 code: ErrorCode::RequestNamespaceInvalid,
1313 message: None,
1314 data: None,
1315 }))?;
1316 }
1317 }
1318
1319 for apiset in apisets.iter_mut() {
1321 apiset.update();
1323 while let Some((cookie, result)) = apiset.next_result() {
1325 match (cookie, result) {
1326 (cookie, Ok(result)) => {
1328 self.push_outbound_section(Section::Response(ResponseSection {
1329 cookie,
1330 state: RequestState::Complete,
1331 result,
1332 }))?;
1333 }
1334 (cookie, Err(error_code)) => {
1336 self.push_outbound_section(Section::Error(ErrorSection {
1337 cookie: Some(cookie),
1338 code: error_code,
1339 message: None,
1340 data: None,
1341 }))?;
1342 }
1343 }
1344 }
1345 }
1346 Ok(())
1347 }
1348
1349 pub fn client_call(
1351 &mut self,
1352 namespace: &str,
1353 function: &str,
1354 version: i32,
1355 arguments: bson::document::Document,
1356 ) -> Result<RequestCookie, Error> {
1357 let cookie = self.next_cookie;
1359 self.next_cookie += 1;
1360
1361 self.push_outbound_section(Section::Request(RequestSection {
1363 cookie: Some(cookie),
1364 namespace: namespace.to_string(),
1365 function: function.to_string(),
1366 version,
1367 arguments,
1368 }))?;
1369
1370 Ok(cookie)
1371 }
1372
1373 pub fn client_drain_responses(&mut self) -> std::collections::vec_deque::Drain<Response> {
1375 self.inbound_responses.drain(..)
1376 }
1377
1378 pub fn client_next_response(&mut self) -> Option<Response> {
1380 self.inbound_responses.pop_front()
1381 }
1382}
1383
/// Round-trips messages between two sessions over a loopback TCP connection.
///
/// It checks four things:
/// - Reading with nothing sent yields no message.
/// - Serializing with no pending sections sends nothing.
/// - A single Error section arrives intact.
/// - A message holding one Error, one Request and one Response section arrives with each
///   section exactly once and every field preserved.
#[test]
fn test_honk_client_read_write() -> anyhow::Result<()> {
    // Upper bound on how long we poll the non-blocking socket for an expected message.
    // Without it, a message that never arrives would spin this test forever instead of
    // failing it.
    const READ_DEADLINE: std::time::Duration = std::time::Duration::from_secs(10);
    // Pause between polls so the wait loops do not peg a CPU core.
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(1);

    let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0u16));
    let listener = TcpListener::bind(socket_addr)?;
    let socket_addr = listener.local_addr()?;

    let stream1 = TcpStream::connect(socket_addr)?;
    stream1.set_nonblocking(true)?;
    let (stream2, _socket_addr) = listener.accept()?;
    stream2.set_nonblocking(true)?;

    let mut alice = Session::new(stream1);
    let mut pat = Session::new(stream2);

    println!("--- pat reads message, but none has been sent");

    assert!(pat.read_message()?.is_none());

    println!("--- alice sends no message, but no pending sections so no message sent");

    alice.serialize_messages()?;
    alice.write_pending_data()?;

    println!("--- pat reads message, but none has been sent");

    match pat.read_message() {
        Ok(Some(msg)) => panic!(
            "message should not have been sent: {}",
            bson::document::Document::from(msg)
        ),
        Ok(None) => {}
        Err(err) => panic!("{:?}", err),
    }

    println!("--- pat sends an error message");

    const CUSTOM_ERROR: &str = "Custom Error!";

    pat.push_outbound_section(Section::Error(ErrorSection {
        cookie: Some(42069),
        code: ErrorCode::Runtime(1),
        message: Some(CUSTOM_ERROR.to_string()),
        data: None,
    }))?;

    pat.serialize_messages()?;
    pat.write_pending_data()?;

    println!("--- alice reads and verifies message");

    // Poll until the message shows up, failing (rather than hanging) past the deadline.
    let deadline = std::time::Instant::now() + READ_DEADLINE;
    let mut msg = loop {
        if let Some(msg) = alice.read_message()? {
            break msg;
        }
        assert!(
            std::time::Instant::now() < deadline,
            "alice did not receive pat's error message within {:?}",
            READ_DEADLINE
        );
        std::thread::sleep(POLL_INTERVAL);
    };

    assert_eq!(msg.sections.len(), 1);
    match msg.sections.pop() {
        Some(Section::Error(section)) => match (section.cookie, section.code, section.message) {
            (Some(42069), ErrorCode::Runtime(1), Some(message)) => {
                assert_eq!(message, CUSTOM_ERROR);
            }
            (cookie, code, message) => panic!(
                "unexpected error section: cookie: {:?}, code: {:?}, message: {:?}",
                cookie, code, message
            ),
        },
        Some(_) => panic!("was expecting an Error section"),
        None => panic!("we should have a message"),
    }

    println!("--- alice sends multi-section message");

    alice.push_outbound_section(Section::Error(ErrorSection {
        cookie: Some(42069),
        code: ErrorCode::Runtime(2),
        message: Some(CUSTOM_ERROR.to_string()),
        data: None,
    }))?;
    alice.push_outbound_section(Section::Request(RequestSection {
        cookie: None,
        namespace: "std".to_string(),
        function: "print".to_string(),
        version: 0,
        arguments: doc! {"message": "hello!"},
    }))?;
    alice.push_outbound_section(Section::Response(ResponseSection {
        cookie: 123456,
        state: RequestState::Pending,
        result: None,
    }))?;

    alice.serialize_messages()?;
    alice.write_pending_data()?;

    println!("--- pat reads and verifies multi-section message");

    let deadline = std::time::Instant::now() + READ_DEADLINE;
    let msg = loop {
        if let Some(msg) = pat.read_message()? {
            break msg;
        }
        assert!(
            std::time::Instant::now() < deadline,
            "pat did not receive alice's multi-section message within {:?}",
            READ_DEADLINE
        );
        std::thread::sleep(POLL_INTERVAL);
    };

    assert_eq!(msg.sections.len(), 3);
    // Count each section kind: checking fields per-variant alone would let a message that
    // carries, say, three copies of the Error section pass.
    let (mut error_count, mut request_count, mut response_count) = (0usize, 0usize, 0usize);
    for section in msg.sections.iter() {
        match section {
            Section::Error(section) => {
                error_count += 1;
                assert_eq!(section.cookie, Some(42069));
                assert_eq!(section.code, ErrorCode::Runtime(2));
                assert_eq!(section.message, Some(CUSTOM_ERROR.to_string()));
                assert_eq!(section.data, None);
            }
            Section::Request(section) => {
                request_count += 1;
                assert_eq!(section.cookie, None);
                assert_eq!(section.namespace, "std");
                assert_eq!(section.function, "print");
                assert_eq!(section.version, 0i32);
                assert_eq!(section.arguments, doc! {"message": "hello!"});
            }
            Section::Response(section) => {
                response_count += 1;
                assert_eq!(section.cookie, 123456);
                assert_eq!(section.state, RequestState::Pending);
                assert_eq!(section.result, None);
            }
        }
    }
    assert_eq!(
        (error_count, request_count, response_count),
        (1, 1, 1),
        "expected exactly one Error, one Request and one Response section"
    );

    Ok(())
}
1521
/// Minimal [`ApiSet`] used by the tests.
///
/// It serves the `namespace` namespace and counts how many times `function` at version 0
/// has been executed.
#[cfg(test)]
struct TestApiSet {
    /// Number of `function` (version 0) executions observed so far.
    call_count: usize,
}

#[cfg(test)]
impl ApiSet for TestApiSet {
    fn namespace(&self) -> &str {
        "namespace"
    }

    /// Increments `call_count` (and logs) only for `function` at version 0.
    ///
    /// The arguments and request cookie are ignored. Every call returns `Some(Ok(None))`,
    /// including calls to unknown names or versions.
    fn exec_function(
        &mut self,
        name: &str,
        version: i32,
        _args: bson::document::Document,
        _request_section: Option<RequestCookie>,
    ) -> Option<Result<Option<bson::Bson>, ErrorCode>> {
        let is_function_v0 = name == "function" && version == 0;
        if is_function_v0 {
            println!("--- namespace::function_0() called");
            self.call_count += 1;
        }
        Some(Ok(None))
    }
}
1550
/// Checks that a session's `max_wait_time` is enforced.
///
/// Alice's wait limit is set to 3 seconds. Two calls from pat, each preceded by a 2-second
/// sleep, stay under that limit. After a final 4-second silence, alice's next `update`
/// must fail with `Error::MessageReadTimedOut`.
#[test]
fn test_honk_timeout() -> anyhow::Result<()> {
    // Loopback listener on an OS-assigned port; alice connects, pat is the accepted side.
    let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0u16)))?;
    let listen_addr = listener.local_addr()?;

    let alice_stream = TcpStream::connect(listen_addr)?;
    alice_stream.set_nonblocking(true)?;
    alice_stream.set_nodelay(true)?;
    println!("--- alice peer_addr: {}", alice_stream.peer_addr()?);
    let (pat_stream, _) = listener.accept()?;
    pat_stream.set_nonblocking(true)?;
    pat_stream.set_nodelay(true)?;

    let mut alice = Session::new(alice_stream);
    let mut alice_apiset = TestApiSet { call_count: 0 };
    let mut pat = Session::new(pat_stream);

    let start = std::time::Instant::now();
    // Every progress line is prefixed with the time elapsed since `start`.
    let log_step = |step: &str| println!("--- {:?} {}", start.elapsed(), step);

    log_step("alice set max_wait_time to 3 seconds");
    alice.update(None)?;
    alice.set_max_wait_time(std::time::Duration::from_secs(3));
    alice.update(None)?;

    log_step("sleep 2 seconds");
    std::thread::sleep(std::time::Duration::from_secs(2));

    log_step("pat calls namespace::function_0()");
    pat.client_call("namespace", "function", 0, doc! {})?;
    // Pump both sessions until alice's api set has executed the first call.
    while alice_apiset.call_count != 1 {
        pat.update(None)?;
        alice.update(Some(&mut [&mut alice_apiset]))?;
    }

    log_step("sleep 2 seconds");
    std::thread::sleep(std::time::Duration::from_secs(2));
    pat.update(None)?;
    alice.update(None)?;

    log_step("pat calls namespace::function_0()");
    pat.client_call("namespace", "function", 0, doc! {})?;
    while alice_apiset.call_count != 2 {
        pat.update(None)?;
        alice.update(Some(&mut [&mut alice_apiset]))?;
    }

    // 4 seconds of silence exceeds alice's 3-second limit.
    log_step("sleep 4 seconds");
    std::thread::sleep(std::time::Duration::from_secs(4));

    log_step("pat+alice update");
    pat.update(None)?;
    match alice.update(None) {
        Ok(()) => panic!("should have timed out"),
        Err(Error::MessageReadTimedOut(duration)) => {
            println!("--- expected time out after {:?}", duration)
        }
        Err(err) => panic!("unexpected error: {:?}", err),
    }
    Ok(())
}